### Created by Jaho Koo, IHE Delft, TU Delft, K-water

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader



# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Resolved once, at import time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


def get_device():
    """Return the module-wide ``torch.device`` selected at import time."""
    return device

class MCDropoutBNN(nn.Module):
    """Feed-forward network with dropout after every hidden layer, for MC-dropout uncertainty.

    Architecture: for each width in ``hidden_dims`` a ``Linear`` layer, an
    activation and a ``Dropout`` layer, followed by a final ``Linear``
    projection to ``output_dim`` (no activation or dropout on the output).

    Args:
        input_dim: Number of input features.
        hidden_dims: Iterable of hidden-layer widths (may be empty).
        output_dim: Number of output features.
        dropout_rate: Dropout probability used after every hidden activation.
        activation_fn: Activation *class* (not an instance); a fresh instance
            is created for each hidden layer.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.1, activation_fn=nn.ReLU):
        super(MCDropoutBNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate
        self.activation_fn = activation_fn

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(self.activation_fn())
            layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, output_dim))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Single forward pass; dropout is active only if the module is in training mode."""
        return self.layers(x)

    def _capture_modes(self):
        """Return the ``training`` flag of every module in this network (self included)."""
        return [module.training for module in self.modules()]

    def _restore_modes(self, modes):
        """Restore ``training`` flags previously returned by ``_capture_modes``."""
        for module, mode in zip(self.modules(), modes):
            module.training = mode

    def predict(self, x, num_samples=100):
        """Monte Carlo dropout prediction.

        Runs ``num_samples`` stochastic forward passes with dropout enabled and
        no gradient tracking. The caller's train/eval mode is restored afterwards.

        Args:
            x: Input tensor accepted by ``forward``.
            num_samples: Number of stochastic passes (must be >= 1).

        Returns:
            ``(mean, std)`` tensors with the shape of a single forward output,
            reduced over the samples. ``std`` uses PyTorch's default (unbiased)
            estimator, so it is NaN when ``num_samples == 1``.
        """
        saved_modes = self._capture_modes()
        self.train()  # enable dropout during inference so each pass samples a sub-network
        try:
            with torch.no_grad():
                predictions = torch.stack([self.forward(x) for _ in range(num_samples)])
        finally:
            # Don't leak training mode to the caller (e.g. one that called model.eval()).
            self._restore_modes(saved_modes)
        return predictions.mean(dim=0), predictions.std(dim=0)

    def predict_scenarios(self, x, num_scenarios=1000, ST=10):
        """Draw ``num_scenarios`` scenarios, each the mean of ``ST`` dropout-enabled passes.

        The caller's train/eval mode is restored afterwards.

        Returns:
            NumPy array of shape ``(num_scenarios, *forward(x).shape)``.
        """
        saved_modes = self._capture_modes()
        self.train()  # dropout must be active for the passes to differ
        try:
            scenarios = []
            with torch.no_grad():
                for _ in range(num_scenarios):
                    scenario = []
                    for _ in range(ST):
                        output = self.forward(x)
                        scenario.append(output.detach().cpu().numpy())
                    scenarios.append(np.mean(scenario, axis=0))
        finally:
            self._restore_modes(saved_modes)
        return np.array(scenarios)



def get_activation(name):
    """Map an activation name to its ``torch.nn`` module class (not an instance).

    Supported names are ``'relu'``, ``'tanh'`` and ``'sigmoid'``; any other
    name raises ``KeyError``.
    """
    activations = dict(relu=nn.ReLU, tanh=nn.Tanh, sigmoid=nn.Sigmoid)
    return activations[name]


class MSELoss_(nn.Module):
    """Mean-squared error plus an optional L2 penalty on all model parameters.

    ``loss = MSE(outputs, targets) + lambda_reg * sum(sum(p ** 2) for p in model.parameters())``

    The penalty covers every tensor returned by ``model.parameters()``
    (weights and biases alike).

    Args:
        lambda_reg: Weight of the L2 penalty. When it is 0 (the default) the
            penalty is skipped entirely and the loss is plain MSE.
    """

    def __init__(self, lambda_reg=0.00):
        super(MSELoss_, self).__init__()
        self.mse = nn.MSELoss()
        self.lambda_reg = lambda_reg

    def forward(self, outputs, targets, model):
        """Return the (optionally L2-regularised) MSE loss for ``model``'s ``outputs``."""
        mse_loss = self.mse(outputs, targets)
        if self.lambda_reg == 0:
            # The penalty would be multiplied by zero; skip the per-batch pass over all parameters.
            return mse_loss
        l2_loss = sum(torch.sum(param ** 2) for param in model.parameters())
        return mse_loss + self.lambda_reg * l2_loss


def train_model_mse(model, X_train, y_train, epochs, lr, batch_size, lambda_reg=0.0):
    """Train ``model`` in place with Adam on ``MSELoss_`` (MSE + optional L2 penalty).

    Runs exactly ``epochs`` shuffled passes over ``(X_train, y_train)`` in
    mini-batches of ``batch_size``; there is no validation or early stopping.

    Returns:
        The same ``model`` object that was passed in.
    """
    loss_fn = MSELoss_(lambda_reg=lambda_reg)
    opt = optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        model.train()
        for inputs, targets in loader:
            opt.zero_grad()
            batch_loss = loss_fn(model(inputs), targets, model)
            batch_loss.backward()
            opt.step()

    return model


class EarlyStopping:
    """Patience-based early-stopping monitor for a loss that should decrease.

    Call the instance once per epoch with the monitored loss. A value counts
    as an improvement only if it is lower than the best seen so far by more
    than ``min_delta``. After ``patience`` consecutive non-improving calls,
    ``early_stop`` becomes True and stays True.
    """

    def __init__(self, patience=10, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive calls without sufficient improvement
        self.best_loss = None  # None until the first call
        self.early_stop = False

    def __call__(self, val_loss):
        # The first observation just becomes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True



def train_model_mse_early_stop(model, X_train, y_train, X_val, y_val, epochs, lr, batch_size, lambda_reg=0.0,
                               patience=20, min_delta=0.001, stop_after_epoch=250):
    """Train ``model`` with Adam on ``MSELoss_``, stopping early on a stalled validation loss.

    Each epoch does one shuffled pass over the training data, then evaluates
    the validation data in ``eval()`` mode without gradients. The early-stopping
    monitor receives the *sum* of per-batch validation losses (each including the
    L2 penalty when ``lambda_reg`` is non-zero), so ``min_delta`` is on the
    scale of that sum.

    Args:
        model: Module to train; updated in place.
        X_train, y_train: Training inputs and targets (zipped by ``TensorDataset``).
        X_val, y_val: Validation inputs and targets.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        batch_size: Mini-batch size for both training and validation loaders.
        lambda_reg: L2 penalty weight passed to ``MSELoss_``.
        patience: Non-improving epochs before the monitor flags a stop.
        min_delta: Minimum decrease of the summed validation loss that counts
            as an improvement.
        stop_after_epoch: Early stopping is honoured only once the 0-based
            epoch index is greater than this value.

    Returns:
        The same ``model`` object. Its weights are those of the last epoch run
        (not necessarily the best validation epoch). If at least one epoch ran,
        the model is left in ``eval()`` mode.
    """
    criterion = MSELoss_(lambda_reg=lambda_reg)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_dataset = TensorDataset(X_train, y_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = TensorDataset(X_val, y_val)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch_X, batch_y in train_dataloader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y, model)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch_X, batch_y in val_dataloader:
                val_outputs = model(batch_X)
                val_loss = criterion(val_outputs, batch_y, model)
                val_losses.append(val_loss.item())

        sum_train_loss = sum(train_losses)
        sum_val_loss = sum(val_losses)

        # The monitor is updated every epoch, but a stop is only honoured after the warm-up period.
        early_stopping(sum_val_loss)
        if early_stopping.early_stop and epoch > stop_after_epoch:
            print(f"Early stopping {epoch+1}/{epochs}, sum train loss {sum_train_loss}, sum val loss {sum_val_loss}")
            break

    return model


def generate_uncertain_scenarios(model, X_input, num_scenarios=100):
    """Sample Gaussian scenarios around a model's MC-dropout prediction.

    Calls ``model.predict(X_input, num_samples=num_scenarios)`` to obtain a
    per-element mean and standard deviation, then draws ``num_scenarios``
    independent samples from ``Normal(mean, std)``.

    Args:
        model: Object with ``eval()`` and ``predict(x, num_samples=...)``
            returning ``(mean, std)`` tensors of equal shape (e.g. ``MCDropoutBNN``).
        X_input: Input passed straight to ``model.predict``.
        num_scenarios: Used both as the number of MC-dropout passes and as the
            number of Gaussian samples drawn.

    Returns:
        ``(scenarios, mean_pred, std_pred)`` where ``scenarios`` has shape
        ``(num_scenarios, *mean_pred.shape)`` for predictions of any rank.
    """
    model.eval()
    with torch.no_grad():
        mean_pred, std_pred = model.predict(X_input, num_samples=num_scenarios)
        # expand() broadcasts without copying and works for any prediction rank;
        # repeat(num_scenarios, 1, 1) assumed rank-2 predictions.
        target_shape = (num_scenarios, *mean_pred.shape)
        scenarios = torch.normal(mean_pred.unsqueeze(0).expand(target_shape),
                                 std_pred.unsqueeze(0).expand(target_shape))

    return scenarios, mean_pred, std_pred

def convert_torch_tensor(X, target_device=None):
    """Copy ``X`` into a new ``float32`` tensor on ``target_device``.

    Args:
        X: Array-like (NumPy array, nested list, scalar) or an existing tensor.
        target_device: Destination device; defaults to the module-level ``device``.

    Returns:
        A new ``torch.float32`` tensor that does not share memory with ``X``
        and is not attached to any autograd graph.
    """
    if target_device is None:
        target_device = device
    if isinstance(X, torch.Tensor):
        # torch.tensor(<Tensor>) also copies, but emits a UserWarning about copy-constructing from a tensor.
        return X.detach().to(device=target_device, dtype=torch.float32, copy=True)
    return torch.tensor(X, dtype=torch.float32).to(target_device)
