# -*- coding: utf-8 -*-
"""
Created on Mon Jul  8 14:28:52 2019

@author: PezijM1
"""

###############################################################################
# import statements
###############################################################################
import os
import numpy as np
import pandas as pd
import pastas as ps

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import datetime

import main_functions as ps_sm

###############################################################################
# script settings
###############################################################################
# binary switches for this script (0 = off, 1 = on)
save_fig = 0  # 1: write the per-station figures to disk
residuals = 0  # not referenced in the code shown here
normal_test = 0  # not referenced in the code shown here

# start from a clean slate: close any figures left open by a previous run
plt.close("all")

# typeset figure text with LaTeX in a serif font
latex = 1
if latex == 1:
    plt.rcParams.update({"text.usetex": True, "font.family": "serif"})

def rmse(predictions, targets):
    """
    Return the root mean square error between predictions and targets.

    Works on NumPy arrays and pandas Series alike. For two Series pandas
    aligns the difference on the index, and ``Series.mean`` skips NaN values;
    for NumPy arrays a NaN propagates into the result.
    """
    residual = targets - predictions
    squared = residual ** 2
    return np.sqrt(squared.mean())

###############################################################################
# read data
###############################################################################

# get station metadata
station_info = ps_sm.read_sm_metadata()

# When the script is re-run in the same namespace (e.g. an interactive
# console), ``data`` is still defined and is reused instead of re-read.
try:
    data
except NameError:
    print("Reading data...")
    # Build the complete dict before binding ``data``: if any reader raises,
    # ``data`` stays undefined and the next run retries the read, instead of
    # reporting "Data already in memory" for a partially filled dict.
    data = {
        "sm": ps_sm.read_sm_data(station_info),
        "prec": ps_sm.read_prec_data(),
        "et_ref": ps_sm.read_et_ref_data_csv(),
        "smap": ps_sm.read_sm_smap(),
    }
else:
    print("Data already in memory")

###############################################################################
# process data
###############################################################################

# define time step
# validation period on which every calibrated model is evaluated
VALIDATION_TMIN = "2018-01-01"
VALIDATION_TMAX = "2019-01-01"

# re-read the station metadata so the station filter below always starts
# from the complete table
station_info = ps_sm.read_sm_metadata()

# stations excluded from the analysis
EXCLUDED_STATIONS = [
    "ITCSM_01",
    "ITCSM_03",
    "ITCSM_05",
    "ITCSM_06",
    "ITCSM_08",
    "ITCSM_12",
    "ITCSM_18",
    "ITCSM_19",
    "ITCSM_20",
]
station_info = station_info[~station_info["station_name"].isin(EXCLUDED_STATIONS)]

# calibration periods (tmin, tmax) of the sensitivity analysis; the dict
# order is the order of the boxes in the summary plot
sensitivity = {
    "Summer": ("2016-04-01", "2016-10-01"),
    "Winter": ("2016-10-01", "2017-04-01"),
    "2016": ("2016-01-01", "2017-01-01"),
    "2017": ("2017-01-01", "2018-01-01"),
    "2016-2017": ("2016-01-01", "2018-01-01"),
}


def _slice_station_data(stat_name):
    """Collect one station's input series and pass them through ps_sm.slice_time_series.

    Keys: "sm5" (in-situ 5 cm VWC), "prec", "et_ref" and "smap".
    """
    data_stat = {
        "sm5": data["sm"][stat_name]["5 cm VWC [m^3/m^3]"],
        "prec": data["prec"][stat_name],
        "et_ref": data["et_ref"][stat_name],
        "smap": data["smap"][stat_name],
    }
    return ps_sm.slice_time_series(data_stat)


def _build_and_solve_model(sliced, stat_name, tmin, tmax):
    """Build a Pastas model of the SMAP series and calibrate it on [tmin, tmax].

    Precipitation ("prec", up=True) and reference ET ("ref_ET", up=False)
    both enter through an exponential response function. Returns the solved
    model.
    """
    ml = ps.Model(oseries=sliced["smap"],
                  name=stat_name,
                  log_level="ERROR",
                  noisemodel=True)

    sm_prec = ps.StressModel(stress=sliced["prec"],
                             rfunc=ps.Exponential,
                             name="prec",
                             up=True)
    sm_et = ps.StressModel(stress=sliced["et_ref"],
                           rfunc=ps.Exponential,
                           name="ref_ET",
                           up=False)
    ml.add_stressmodel(sm_prec)
    ml.add_stressmodel(sm_et)

    ml.solve(solver=ps.LeastSquares,
             tmin=tmin,
             tmax=tmax,
             report=True,
             fit_constant=False)
    return ml


def _save_station_figure(fig, key, stat_name):
    """Save fig as PNG and PDF to figures/paper/sens/<key>/smap_TFN_<stat_name>."""
    dir_name = os.path.join("figures", "paper", "sens", key)
    os.makedirs(dir_name, exist_ok=True)
    save_name = os.path.join(dir_name, "smap_TFN_" + stat_name)
    for ext in (".png", ".pdf"):
        fig.savefig(save_name + ext, dpi=300, bbox_inches="tight")


# per calibration period: list of RMSE(validation simulation, SMAP), one
# value per station
rmse_dict = dict()

for key, value in sensitivity.items():
    tmin, tmax = value

    # per-period diagnostics; only rmse_smap is stored in rmse_dict, the
    # other lists hold the values of the last period after the loop
    EVP = []        # pastas explained variance percentage of each fit
    rmse_obs = []   # RMSE(SMAP, in-situ 5 cm) in the validation period
    rmse_sim = []   # RMSE(simulation, in-situ 5 cm) in the validation period
    rmse_smap = []  # RMSE(simulation, SMAP) in the validation period
    print(value)

    for stat_name in station_info["station_name"]:
        print(stat_name)

        data_stat_sliced = _slice_station_data(stat_name)
        ob_series = data_stat_sliced["smap"]

        ml = _build_and_solve_model(data_stat_sliced, stat_name, tmin, tmax)
        sim = ml.simulate()
        sim2 = ml.simulate(tmin=VALIDATION_TMIN, tmax=VALIDATION_TMAX)

        # Compare on the validation-simulation time steps. reindex (not .loc)
        # turns missing in-situ time steps into NaN, which the Series mean in
        # rmse skips, instead of raising KeyError.
        ix_sim2 = sim2.index
        obs_crop = ob_series.reindex(ix_sim2)
        insitu_crop = data_stat_sliced["sm5"].reindex(ix_sim2)

        rmse_sim.append(rmse(sim2, insitu_crop))
        rmse_obs.append(rmse(obs_crop, insitu_crop))

        # simulation vs SMAP only on time steps that have a SMAP value
        obs_crop = obs_crop.dropna()
        rmse_smap.append(rmse(sim2.reindex(obs_crop.index), obs_crop))

        # plot observations, simulations, in-situ data and stresses
        prec = ml.get_stress("prec")
        et = ml.get_stress("ref_ET")
        fig, _ = ps_sm.plot_1panels(ob_series,
                                    sim,
                                    sim2,
                                    insitu_crop,
                                    prec,
                                    et,
                                    stat_name)

        EVP.append(ml.stats.evp())

        if save_fig == 1:
            _save_station_figure(fig, key, stat_name)

        plt.close(fig)

    rmse_dict[key] = rmse_smap

#%%
# tick labels (calibration-period keys) and per-station RMSE lists, one box
# per period, in the insertion order of rmse_dict
index = list(rmse_dict.keys())
d_data = list(rmse_dict.values())
    
#%%
fig, ax = plt.subplots(figsize=(6, 3))

# one box per calibration period; outliers are not drawn
bp = ax.boxplot(d_data,
                showfliers=False,
                patch_artist=True)

# boxes: black 1-pt outline (``color`` sets edge and face), then green fill
for box in bp["boxes"]:
    box.set(color="k", linewidth=1)
    box.set(facecolor="#1b9e77")

# whiskers, caps and medians: black, 1 pt
for line in bp["whiskers"] + bp["caps"] + bp["medians"]:
    line.set(color="k", linewidth=1)

# flier style; matplotlib creates no flier artists with showfliers=False,
# so this only takes effect if outliers are switched on above
for flier in bp["fliers"]:
    flier.set(marker="o",
              color="#e7298a",
              alpha=0.5)

# custom horizontal axis labels: spell out the seasonal periods by key (not
# by position), other periods keep their dict key as label
period_labels = {
    "Summer": "Summer \n 2016",
    "Winter": "Winter \n 2016-2017",
}
cus_index = [period_labels.get(label, label) for label in index]
ax.set_xticklabels(cus_index)

# ticks only on the bottom and left axes
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()

# set vertical axis limits
ax.set_ylim(0, 0.3)

# set vertical axis label
ax.set_ylabel('RMSE [$m^3 m^{-3}$]')

#fig.savefig('sensitivity_TFN.pdf',
#            bbox_inches='tight',
#            dpi=300)