import json
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import shutup; shutup.please()
from utils import trim_columns_for_each_experiment

# Let pandas print wide frames: up to 500 columns, with display.width set to 0.
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 0)

# Number of random hyper-parameter draws (one gamma, one nu and one split seed each).
how_many = 200
# Fixed seed so that the draws below are reproducible; the order of the draws matters.
np.random.seed(42)

# Columns treated as metadata (everything else is considered a feature downstream).
META_COLS = [
    'datetime',
    'longitude',
    'latitude',
    'safety_region',
    'geometry',
]
DATA_COLS = [
    'svf-25m',
    'ws_loc',
    'wd_loc',
    'ws_loc_6h',
    'wd_loc_6h',
    'BG2017',
    'trees-pct',
    'trees-height',
    'spi1',
    'spi3',
    'spi6',
    'top10nl_Height',
    'dist-top10nl_waterbody_line',
    'dist-top10nl_waterbody_surface',
    'trees_pixel',
    'trees_buff',
    'closest_tree',
    'avg_dist',
]

# Gamma controls how much curvature the RBF decision boundary is allowed to have.
# Drawn uniformly from [0.1, 3) and rounded to two decimals.
GAMMAS = np.random.uniform(0.1, 3, how_many).round(2)
print("\nGammas: ", GAMMAS)

# Nu sets the fraction of outliers we expect to find in the training data.
# Drawn uniformly from [0.01, 0.1) and rounded to two decimals.
NUS = np.random.uniform(0.01, 0.1, how_many).round(2)
print("\nNus: ", NUS)

# Fraction of the samples handed to the training side of each split.
TRAINING_SIZE = 0.7

# One integer seed per draw, taken from [1, how_many), used for the train/test splits.
RANDOM_STATES = np.random.randint(1, how_many, how_many)
print("\nRandom states: ", RANDOM_STATES)


# ################
# # Main program #
# ################

# Grid sizes; each value is formatted into the "{0}m" parts of the grid/prediction paths.
spaces = [1000, 5000]

# English month names, indexed with (month number - 1).
lmonname = [
    "January",
    "February",
    "March",
    "April",
    "May",
    "June",
    "July",
    "August",
    "September",
    "October",
    "November",
    "December",
]

folder_name = "four_experiments"

# Every experiment is a (features_version, experiment_name) pair; all four share version 1.4.
experiments = [
    ("1.4", exp_name)
    for exp_name in (
        "landuse-cobra",
        "no-landuse-cobra",
        "landuse-no-cobra",
        "no-landuse-no-cobra",
    )
]

# Training stage
# -------------------------------------------------------------
# For every experiment, the storm-damage observations are read, cleaned, trimmed
# to the experiment's columns and grouped by calendar month. For each month an
# ensemble of One-Class SVMs is trained, one model per (gamma, nu, random_state)
# draw, and the parameters of every fitted model are logged to a text file.

# Path templates do not depend on the experiment, so they are defined once.
# path_in_grid and path_ou_csv are consumed by the prediction stage further down.
path_in_dmg = r"../data/features/Stormdamage_v1.4.csv"
path_in_grid = r"../data/grid/{0}m/Grid_feat_{0}m_{1}.csv"
path_ou_csv = r"../data/predictions/csv/four_experiments/{2}/{0}m_ensemble_boot/Grid_pred_{0}m_boot_{1}.csv"
path_ou_param = r"../data/params/{0}"

# experiment name -> {(gamma, nu): [fitted OneClassSVM, ...]}
dic_models_per_experiment = defaultdict(dict)
for experiment in experiments:
    features_version, name_exp = experiment

    # Human-readable titles per experiment; the version is now taken from the
    # experiment tuple instead of being hard-coded in the text.
    dic_suptitle = {
        "landuse-cobra": "v{0} (Exp1: incl. land use and Cobra data)".format(features_version),
        "no-landuse-cobra": "v{0} (Exp2: no land use, incl. Cobra data)".format(features_version),
        "landuse-no-cobra": "v{0} (Exp3: incl. land use, no Cobra data)".format(features_version),
        "no-landuse-no-cobra": "v{0} (Exp4: no land use, no Cobra data)".format(features_version),
    }

    # Reading and pre-processing data
    # -------------------------------------------------------------
    print("\nReading data and sanity checks")
    print("-" * 80)
    # infer_datetime_format is deprecated in pandas; parse_dates alone parses the column.
    df_dmg = pd.read_csv(path_in_dmg, sep=";", header=0, index_col=["rowid"], parse_dates=["datetime"])
    print("Original shape: ", df_dmg.shape)
    # Rows with a negative sky-view factor are blanked out entirely, so the
    # dropna below removes them together with any other incomplete row.
    df_dmg = df_dmg.mask(df_dmg["svf-25m"] < 0)
    print("Clean SVF: ", df_dmg.shape)
    df_dmg = df_dmg.dropna(axis=0)  # Drops observations from 2023, since we have no SPI for them
    print("Drops rows with NaN: ", df_dmg.shape)
    print("Years in the train/test dataset: ", np.unique(df_dmg["datetime"].dt.year))

    df_dmg = trim_columns_for_each_experiment(name_exp, df_dmg)

    # Every column that is not metadata is used as a training feature.
    DATA_COLS_EXP = df_dmg.columns[~df_dmg.columns.isin(META_COLS)]
    print("Experiment: \t\t{0}".format(name_exp))
    print("Training columns: ", DATA_COLS_EXP.tolist())
    print("-" * 80)
    print()

    # The parameter log is opened once per experiment. It used to be reopened in
    # "w" mode for every month, which kept only the last month's records; each
    # record now carries a 'month' entry and all months are retained.
    filename = path_ou_param.format("ensemble_params_boot_permon_v{0}_{1}m_{2}.txt".format(features_version, spaces[0], name_exp))
    with open(filename, "w", newline="") as w:
        # groupby yields the months in ascending order.
        for month, df in df_dmg.groupby(df_dmg["datetime"].dt.month):
            print("\nProcessing: ", lmonname[month - 1])

            # Scaling the features of this month to [0, 1]
            scaler = MinMaxScaler()
            arr_scaled = scaler.fit_transform(df[DATA_COLS_EXP].to_numpy())

            # Training all the models in the ensemble with the selected randomized parameters
            dic_models = defaultdict(list)
            for gamma, nu, seed in zip(GAMMAS, NUS, RANDOM_STATES):
                # Built-in scalars keep the logged repr free of NumPy wrappers
                # (e.g. "np.float64(1.19)" under NumPy >= 2).
                gamma, nu, seed = float(gamma), float(nu), int(seed)

                clf = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu, shrinking=True)
                # Only the training part of the split is used; the rest is discarded.
                xtrain, _ = train_test_split(arr_scaled, train_size=TRAINING_SIZE, random_state=seed)
                dic_models[(gamma, nu)].append(clf.fit(xtrain))

                dic_params = clf.get_params()
                dic_params['features_version'] = features_version
                dic_params['spaces'] = spaces[0]
                dic_params['month'] = int(month)
                dic_params['n_support'] = clf.n_support_.tolist()
                dic_params['idx_support'] = clf.support_.tolist()
                dic_params['fit_status'] = clf.fit_status_
                dic_params['random_state'] = seed

                w.write(str(dic_params))
                w.write("\n")

            # NOTE(review): this entry is overwritten on every month, so only the
            # ensemble of the last month processed is available to the prediction
            # stage -- confirm whether a per-month lookup was intended.
            dic_models_per_experiment[name_exp] = dic_models

# Applying the ensemble of models to the geographic space
# -------------------------------------------------------------
# For every experiment, grid size and hourly frame in [sd, ed), the grid
# features are read, cleaned and scaled, every model of the experiment's ensemble
# predicts on them, and the predictions are written next to the metadata columns.
sd = datetime.datetime(2022, 2, 18)
ed = datetime.datetime(2022, 2, 19)
# Hourly steps; the last stamp (ed itself) is dropped. A Timedelta frequency is
# used instead of the "H" alias, which recent pandas versions deprecate.
date_range = pd.date_range(start=sd, end=ed, freq=pd.Timedelta(hours=1))[:-1]

for experiment in experiments:
    print("\n\nRUNNING: ", experiment)
    features_version, name_exp = experiment

    # The ensemble does not depend on the grid size or the date, so it is looked
    # up once. Only the first model stored for each (gamma, nu) pair is used.
    dic_models = dic_models_per_experiment[name_exp]
    ensemble = [dic_models[params][0] for params in sorted(dic_models.keys())]

    for s in spaces:

        dates_not_found = []
        for date in date_range:
            print("Processing: ", date)
            # Reading each of the hourly frames and its features
            date_str = date.strftime("%Y-%m-%d_%H")
            path_cur = path_in_grid.format(s, date_str)
            print("Current path: ", path_cur)

            # Reading and pre-processing data
            # -------------------------------------------------------------
            print("\nReading data and sanity checks")
            print("-" * 80)
            # Only a missing input frame is tolerated; errors raised by any later
            # step are no longer mistaken for a missing date.
            try:
                df_cur = pd.read_csv(path_cur, sep=";", header=0, index_col=0)
            except FileNotFoundError:
                dates_not_found.append(date)
                continue

            print("Original shape: ", df_cur.shape)
            # Rows with a negative sky-view factor are blanked out entirely, so the
            # dropna below removes them together with any other incomplete row.
            df_cur = df_cur.mask(df_cur["svf-25m"] < 0)
            print("Clean SVF: ", df_cur.shape)
            df_cur = df_cur.dropna(axis=0)
            print("Drops rows with NaN: ", df_cur.shape)

            # The last column is renamed to 'geometry'.
            df_cur.columns = [*df_cur.columns[:-1], 'geometry']
            # NOTE(review): the per-column fillna calls that used to follow were
            # removed: after dropna no missing value is left, so they could never
            # fill anything. If filling was meant to happen *before* dropping rows,
            # it has to be reinstated ahead of the mask/dropna steps.

            df_cur = trim_columns_for_each_experiment(name_exp, df_cur)

            # Separating data from metadata and scaling
            # NOTE(review): the scaler is re-fitted on each grid frame instead of
            # reusing the scaler fitted on the training data -- confirm intended.
            scaler = MinMaxScaler()
            arr_meta = df_cur[["datetime", "longitude", "latitude", "geometry"]]
            # NOTE(review): features are picked by position (everything between the
            # first three columns and the last one) -- presumably the same columns,
            # in the same order, as used for training; verify.
            arr_data = df_cur.iloc[:, 3:-1].to_numpy()
            xtest_cur = scaler.fit_transform(arr_data)

            # Predicting with every model; all prediction columns are attached in a
            # single concat instead of assigning them one by one onto a slice.
            predictions = {
                "pred_mod{0}".format(str(i).zfill(2)): model.predict(xtest_cur)
                for i, model in enumerate(ensemble)
            }
            df_pred = pd.concat([arr_meta, pd.DataFrame(predictions, index=arr_meta.index)], axis=1)

            path_ou_cur = path_ou_csv.format(s, date_str, name_exp)
            print("Writing in: ", path_ou_cur)
            df_pred.to_csv(path_ou_cur, sep=";", header=True, index=True, index_label="rowid")

        print("\tDates not found: ", np.unique(dates_not_found), len(np.unique(dates_not_found)))
#
