### Created by Jaho Koo, IHE Delft, TU Delft, K-water

import os
import pandas as pd
import time
import torch
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import minimize
from itertools import combinations
from tqdm import tqdm



# Pick the CUDA device when torch reports one is available, otherwise the CPU,
# and report the choice on stdout at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Base directories; all three currently point at the root of drive D:.
# NOTE(review): hard-coded, Windows-specific locations — presumably meant for
# results, inputs and saved models respectively; confirm against the callers.
result_path = "D:\\"
input_path = "D:\\"
model_path = "D:\\"


def scenario_reduce_forward(x, cdn=None, cadd=1, cred=0, w=None, dist="energy", p=1):
    """Reduce a scenario set by greedy forward selection.

    Starting from an empty selection, ``cadd`` scenarios are added per step
    (every combination of the remaining scenarios is evaluated and the one with
    the smallest criterion value is kept) until ``cdn`` scenarios are selected.
    If ``cred > 0``, each forward step that has not yet reached the target is
    followed by a backward step that drops ``cred`` scenarios again.

    Args:
        x: 2-D array of scenarios, one scenario per row.
        cdn: Number of scenarios to keep. Defaults to ``floor(log(n))``, but
            never less than 1.
        cadd: Number of scenarios added per forward step (>= 1).
        cred: Number of scenarios removed per backward step (0 disables it).
        w: Optional weights of the input scenarios; defaults to uniform. The
            weights are normalised by the sum of their absolute values.
        dist: ``"energy"`` selects the SLSQP-optimised energy-type criterion;
            any other value selects the Wasserstein-type (nearest selected
            scenario) criterion.
        p: Exponent applied to the pairwise Euclidean distances; the criterion
            value is raised to ``1 / p``.

    Returns:
        Tuple ``(selected, weights)``: the selected rows of ``x`` (ordered by
        their original row index) and the weight assigned to each of them.

    Raises:
        ValueError: If ``x`` has no rows, ``cdn`` is outside ``[1, n]``,
            ``cadd < 1``, or ``cred >= cadd`` while a backward step could
            trigger (the selection could then never grow to ``cdn``).
    """
    xn, _ = x.shape
    if xn == 0:
        raise ValueError("x must contain at least one scenario")
    if cdn is None:
        # floor(log(n)) is 0 for n < 3; always keep at least one scenario.
        cdn = max(1, int(np.floor(np.log(xn))))
    if not 1 <= cdn <= xn:
        raise ValueError(
            f"cdn must be between 1 and the number of scenarios ({xn}), got {cdn}"
        )
    if cadd < 1:
        raise ValueError(f"cadd must be at least 1, got {cadd}")
    # A backward step only triggers for sizes strictly between cred and cdn.
    # If it removes at least as many scenarios as a forward step adds, the
    # selection cannot be guaranteed to reach cdn.
    if cred >= cadd and cdn > cred + 1:
        raise ValueError(
            f"cred ({cred}) must be smaller than cadd ({cadd}) to reach cdn ({cdn})"
        )

    if w is None:
        w = np.ones(xn) / xn
    else:
        w = np.asarray(w, dtype=float)
    w = w / np.sum(np.abs(w))

    distx = squareform(pdist(x)) ** p
    all_index = np.arange(xn)

    if dist == "energy":
        b = 2 * (distx @ w)

        def get_val(subs, detailed=False, initial_weights=None):
            """Optimise the weights on ``subs`` and return the criterion value."""
            subs = np.asarray(subs, dtype=int)
            # Slice once; the optimiser evaluates the objective many times.
            b_sub = b[subs]
            d_sub = distx[np.ix_(subs, subs)]

            def obj_fun(ww):
                val = ww @ b_sub - ww @ d_sub @ ww
                if abs(val) < 1e-4:  # early termination condition
                    return 0
                return val

            if initial_weights is None:
                initial_weights = np.ones(len(subs)) / len(subs)

            # Weights are constrained to the probability simplex.
            cons = {'type': 'eq', 'fun': lambda ww: np.sum(ww) - 1}
            res = minimize(obj_fun, initial_weights, method='SLSQP',
                           constraints=cons, bounds=[(0, 1)] * len(subs),
                           options={'maxiter': 100, 'ftol': 1e-3, 'disp': False})
            if detailed:
                return (res.fun ** (1 / p), res.x)
            return res.fun ** (1 / p)
    else:  # wasserstein
        def get_val(subs, detailed=False):
            """Cost of moving every unselected scenario to its nearest selected one."""
            subs = np.asarray(subs, dtype=int)
            J = np.setdiff1d(all_index, subs)
            cross = distx[np.ix_(J, subs)]
            cost = np.sum(w[J] * np.min(cross, axis=1)) ** (1 / p)
            if not detailed:
                return cost
            # Columns of `cross` follow the order of `subs`, so the argmin
            # indexes `subs` directly (also when `subs` is not sorted).
            nearest = subs[np.argmin(cross, axis=1)]
            wr = w.copy()
            # Unbuffered add: several dropped scenarios may share one target.
            np.add.at(wr, nearest, w[J])
            return (cost, wr[subs])

    act_index = np.array([], dtype=int)

    while len(act_index) < cdn:
        # Forward step: try every way of adding `add_count` new scenarios.
        add_count = min(cadd, cdn - len(act_index))
        tmp_index = np.setdiff1d(all_index, act_index)

        cbm = np.array(list(combinations(tmp_index, add_count)))
        ccbm = np.column_stack([np.tile(act_index, (len(cbm), 1)), cbm])

        dcbm = np.array([get_val(comb) for comb in ccbm])
        act_index = ccbm[np.argmin(dcbm)]

        # Backward step: skipped once the target size is reached, otherwise
        # the loop would shrink the selection again and never terminate.
        if cred > 0 and cred < len(act_index) < cdn:
            keep_count = len(act_index) - cred
            rcbm = np.array(list(combinations(act_index, keep_count)))
            rdcbm = np.array([get_val(comb) for comb in rcbm])
            act_index = rcbm[np.argmin(rdcbm)]

    act_index = np.sort(act_index)
    act_out = get_val(act_index, detailed=True)
    return x[act_index], act_out[1]


def scenario_reduce_k_median(scenarios, reduced_num, max_iter=1000):
    """Reduce scenarios with a k-medians style clustering.

    The clustering is initialised with scikit-learn's ``KMeans`` and then
    refined by alternating between per-cluster coordinate-wise medians and a
    reassignment of every scenario to its nearest median under the Manhattan
    (cityblock) distance. Each cluster is represented by the actual scenario
    closest to the cluster median.

    Args:
        scenarios: 2-D array with one scenario per row.
        reduced_num: Number of clusters / reduced scenarios.
        max_iter: Maximum number of median/reassignment iterations.

    Returns:
        Tuple ``(reduced_scenarios, probabilities)`` where
        ``reduced_scenarios`` has shape ``(reduced_num, d)`` and
        ``probabilities[i]`` is the fraction of scenarios in cluster ``i``.
    """
    n, d = scenarios.shape

    kmeans = KMeans(n_clusters=reduced_num, n_init=10, max_iter=300)
    labels = kmeans.fit_predict(scenarios)
    centers = np.array(kmeans.cluster_centers_, dtype=float)

    def _medians(current_labels, previous):
        """Per-cluster medians; an empty cluster keeps its previous centre."""
        updated = previous.copy()
        for i in range(reduced_num):
            members = scenarios[current_labels == i]
            # np.median of an empty cluster is NaN, and NaN distances would
            # make argmin assign every scenario to that cluster.
            if len(members) > 0:
                updated[i] = np.median(members, axis=0)
        return updated

    for _ in range(max_iter):
        centers = _medians(labels, centers)
        distances = cdist(scenarios, centers, metric='cityblock')
        new_labels = np.argmin(distances, axis=1)
        if np.array_equal(labels, new_labels):
            break
        labels = new_labels

    final_medians = _medians(labels, centers)
    reduced_scenarios = np.zeros((reduced_num, d))
    for i in range(reduced_num):
        cluster_scenarios = scenarios[labels == i]
        # An empty cluster falls back to the closest scenario overall.
        candidates = cluster_scenarios if len(cluster_scenarios) > 0 else scenarios
        candidate_distances = cdist(candidates, [final_medians[i]], metric='cityblock').ravel()
        reduced_scenarios[i] = candidates[np.argmin(candidate_distances)]
    probabilities = np.array([(labels == i).sum() / n for i in range(reduced_num)])

    return reduced_scenarios, probabilities


def scenario_reduce_manhattan(scenarios, reduced_num, max_iter=1000, w=None):
    """Reduce scenarios by clustering with Manhattan-distance reassignment.

    The clustering is initialised with scikit-learn's ``KMeans`` (optionally
    using ``w`` as sample weights) and refined by alternating between
    per-cluster means and a reassignment of every scenario to its nearest mean
    under the Manhattan (cityblock) distance. Each cluster is represented by
    the actual scenario closest to the cluster's coordinate-wise median.

    Args:
        scenarios: 2-D array with one scenario per row.
        reduced_num: Number of clusters / reduced scenarios.
        max_iter: Maximum number of mean/reassignment iterations.
        w: Optional sample weights passed to the initial ``KMeans`` fit.
            NOTE(review): the weights affect only that initial fit; the mean
            updates and the returned probabilities are unweighted — confirm
            that this is intended.

    Returns:
        Tuple ``(reduced_scenarios, probabilities)`` where
        ``reduced_scenarios`` has shape ``(reduced_num, d)`` and
        ``probabilities[i]`` is the fraction of scenarios in cluster ``i``.
    """
    n, d = scenarios.shape
    kmeans = KMeans(n_clusters=reduced_num, n_init=10, max_iter=300)
    # fit_predict returns the label array; fit() would return the estimator.
    # `is None` avoids an ambiguous element-wise comparison for array weights.
    if w is None:
        labels = kmeans.fit_predict(scenarios)
    else:
        labels = kmeans.fit_predict(scenarios, sample_weight=w)
    centers = np.array(kmeans.cluster_centers_, dtype=float)

    def _centers(current_labels, previous, reducer):
        """Per-cluster centres via `reducer`; empty clusters keep the previous one."""
        updated = previous.copy()
        for i in range(reduced_num):
            members = scenarios[current_labels == i]
            # Reducing an empty cluster gives NaN, and NaN distances would
            # make argmin assign every scenario to that cluster.
            if len(members) > 0:
                updated[i] = reducer(members, axis=0)
        return updated

    for _ in range(max_iter):
        centers = _centers(labels, centers, np.mean)
        distances = cdist(scenarios, centers, metric='cityblock')  # cityblock is the Manhattan distance
        new_labels = np.argmin(distances, axis=1)
        if np.array_equal(labels, new_labels):
            break
        labels = new_labels

    # The representative is chosen relative to the cluster median.
    final_medians = _centers(labels, centers, np.median)
    reduced_scenarios = np.zeros((reduced_num, d))
    for i in range(reduced_num):
        cluster_scenarios = scenarios[labels == i]
        # An empty cluster falls back to the closest scenario overall.
        candidates = cluster_scenarios if len(cluster_scenarios) > 0 else scenarios
        candidate_distances = cdist(candidates, [final_medians[i]], metric='cityblock').ravel()
        reduced_scenarios[i] = candidates[np.argmin(candidate_distances)]

    probabilities = np.array([(labels == i).sum() / n for i in range(reduced_num)])

    return reduced_scenarios, probabilities



def scenario_reduce_k_mean(scenarios, reduced_num, w=None):
    """Reduce scenarios with k-means clustering.

    Fits scikit-learn's ``KMeans`` (``random_state=42``) and represents each
    cluster centre by the input scenario closest to it in Euclidean distance.

    Args:
        scenarios: 2-D array with one scenario per row.
        reduced_num: Number of clusters / reduced scenarios.
        w: Optional per-scenario weights, used as ``sample_weight`` for the
            fit and summed per cluster to form the returned probabilities.

    Returns:
        Tuple ``(reduced_scenarios, reduced_probabilities)``. Without ``w``
        the probabilities are the cluster sizes divided by the number of
        scenarios; with ``w`` they are the summed weights of each cluster
        (not re-normalised).
    """
    kmeans = KMeans(n_clusters=reduced_num, random_state=42)
    # `is None` avoids an ambiguous element-wise comparison for array weights.
    if w is None:
        kmeans.fit(scenarios)
    else:
        w = np.asarray(w, dtype=float)
        kmeans.fit(scenarios, sample_weight=w)

    distances = cdist(scenarios, kmeans.cluster_centers_, metric='euclidean')
    # For every centre (column), the row index of the nearest scenario.
    closest_scenario_indices = distances.argmin(axis=0)
    reduced_scenarios = scenarios[closest_scenario_indices]

    # minlength guarantees one entry per cluster even if trailing clusters
    # received no scenarios.
    if w is None:
        cluster_sizes = np.bincount(kmeans.labels_, minlength=reduced_num)
        reduced_probabilities = cluster_sizes / len(scenarios)
    else:
        reduced_probabilities = np.bincount(
            kmeans.labels_, weights=w, minlength=reduced_num
        )

    return reduced_scenarios, reduced_probabilities


def Scenarios_1S_(predictions, F, T_scenarios = 10, method='wasserstein'):
    """Reduce ``predictions`` to ``T_scenarios`` scenarios and return a table.

    Args:
        predictions: 2-D array of scenarios, one per row.
        F: Number of columns of ``predictions``; the output columns are
            labelled ``0 .. F-1``.
        T_scenarios: Number of scenarios to keep.
        method: One of ``'wasserstein'``, ``'energy'``, ``'kmedian'`` or
            ``'kmean'``.

    Returns:
        ``pandas.DataFrame`` with one row per selected scenario, columns
        ``0 .. F-1`` holding the scenario values and column ``'p'`` holding
        the scenario probability/weight.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == 'wasserstein':
        selected_scenarios, probabilities = scenario_reduce_forward(
            predictions, cdn=T_scenarios, cadd=1, cred=0, dist='wasserstein')
    elif method == 'energy':
        selected_scenarios, probabilities = scenario_reduce_forward(
            predictions, cdn=T_scenarios, cadd=1, cred=0, dist='energy')
    elif method == 'kmedian':
        selected_scenarios, probabilities = scenario_reduce_k_median(
            scenarios=predictions, reduced_num=T_scenarios)
    elif method == 'kmean':
        selected_scenarios, probabilities = scenario_reduce_k_mean(
            scenarios=predictions, reduced_num=T_scenarios)
    else:
        # Fail explicitly instead of printing and then crashing on the
        # unassigned result variables below.
        raise ValueError(
            f"Unknown scenario reduction method {method!r}; expected one of "
            "'wasserstein', 'energy', 'kmedian', 'kmean'"
        )
    sc = pd.DataFrame(selected_scenarios, columns=list(range(F)))
    sc['p'] = probabilities
    return sc