# -*- coding: utf-8 -*-
"""
Created on Mon Jul  8 14:28:52 2019

@author: PezijM1
"""

###############################################################################
# import statements
###############################################################################
import os
import numpy as np
import pandas as pd
import pastas as ps

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import datetime

import main_functions as ps_sm

###############################################################################
# script settings
###############################################################################
# run switches (1 = on, 0 = off)
save_fig = 0     # write the per-station figures to disk
residuals = 0    # not referenced in the visible part of this script
normal_test = 0  # not referenced in the visible part of this script

# start from a clean slate: close figures left over from a previous run
plt.close('all')

# render all figure text through LaTeX, using a serif font
latex = 1
if latex == 1:
    plt.rcParams.update({'text.usetex': True,
                         'font.family': 'serif'})
    
def rmse(predictions, targets):
    '''
    Return the root mean square error between ``predictions`` and ``targets``.

    Works element-wise on numpy arrays or pandas Series. For Series the
    subtraction aligns on the index and pandas' ``mean`` skips NaN, so only
    time steps present (and non-missing) in both series contribute.
    '''
    residual = targets - predictions
    mean_square = (residual ** 2).mean()
    return np.sqrt(mean_square)

def setlabel(ax, label, loc=2, borderpad=0.2, **kwargs):
    '''
    Annotate a subplot with a panel label such as "(A)", "(B)", ...

    The label is drawn as a frameless legend without a handle, anchored at
    ``loc`` (2 = upper left). Any legend already on ``ax`` is kept by
    re-attaching it as a plain artist, and the label legend itself is moved
    out of the axes' legend slot into the artist list, so neither is lost
    when ``ax.legend()`` is called again later.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to annotate.
    label : str
        Text of the panel label.
    loc : int or str, optional
        Legend location passed to ``ax.legend`` (default 2).
    borderpad : float, optional
        Legend border padding passed to ``ax.legend`` (default 0.2).
    **kwargs
        Further keyword arguments forwarded to ``ax.legend``.

    Derived from:
        https://stackoverflow.com/questions/22508590/enumerate-plots-in-matplotlib-figure
    '''

    # keep an existing legend alive: the ax.legend() call below would
    # otherwise replace it
    legend = ax.get_legend()
    if legend is not None:
        ax.add_artist(legend)

    # invisible dummy line whose only purpose is to carry the label text
    # (np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0)
    line, = ax.plot(np.nan, np.nan, color='none', label=label)

    label_legend = ax.legend(handles=[line],
                             loc=loc,
                             handlelength=0,
                             handleheight=0,
                             handletextpad=0,
                             borderaxespad=0,
                             borderpad=borderpad,
                             frameon=False,
                             **kwargs)

    # take the label out of the axes' legend slot and keep it as a regular
    # artist, so a later ax.legend() call does not overwrite it
    label_legend.remove()
    ax.add_artist(label_legend)

    # the dummy line is no longer needed once the legend has been built
    line.remove()
    
    
###############################################################################
# read data
###############################################################################

# get station metadata
station_info = ps_sm.read_sm_metadata()

# Reuse `data` when it survives from a previous run in the same interactive
# session (presumably Spyder/IPython, given the #%% cell marker), so the
# datasets are only read once.
try:
    data  # only checks whether the name is already bound
except NameError:
    print('Reading data...')

    # Build the dictionary in a single expression so `data` is bound only
    # after every dataset has loaded. Before, a failed read left a partially
    # filled `data` behind, which the check above then accepted on the next
    # run (leading to KeyErrors further down).
    data = {
        'sm': ps_sm.read_sm_data(station_info),
        'prec': ps_sm.read_prec_data(),
        'et_ref': ps_sm.read_et_ref_data_csv(),
        'smap': ps_sm.read_sm_smap(),
    }
else:
    print('Data already in memory')
    
###############################################################################
# process data
###############################################################################
  
# per-station results, appended to in the station loop below
EVP = []        # explained variance percentage of each fitted model (ml.stats.evp)
rmse_obs = []   # RMSE of SMAP vs in situ over the 2016 validation window
rmse_sim = []   # RMSE of TFN prediction vs in situ over the 2016 window
rmse_SMAP = []  # RMSE of TFN prediction vs SMAP over the 2016 window

# zero-based positions (in station_info order) of the stations shown with
# "Field validation" markers in the summary figure
insit_loc = [1, 3, 6, 8, 9, 10, 12, 13, 14, 15, 16]

# NOTE: station_info is already read in the "read data" section; the former
# second call to ps_sm.read_sm_metadata() here was redundant I/O.

# for every station: fit a TFN model to SMAP, validate on 2016, plot
for _, row in station_info.iterrows():

    # get station name
    stat_name = row['station_name']
    print(stat_name)

    # slice station data
    data_stat = {
        'sm5': data['sm'][stat_name]['5 cm VWC [m^3/m^3]'],
        'prec': data['prec'][stat_name],
        'et_ref': data['et_ref'][stat_name],
        'smap': data['smap'][stat_name],
    }

    # slice time series
    data_stat_sliced = ps_sm.slice_time_series(data_stat)

    # SMAP soil moisture is the observation series the model is fitted to
    ob_series = data_stat_sliced['smap']

    # create a Pastas time series model object
    ml = ps.Model(oseries=ob_series,
                  name=stat_name,
                  log_level='INFO',
                  noisemodel=True)

    # stress series: precipitation (up=True) and reference ET (up=False)
    sm_prec = ps.StressModel(stress=data_stat_sliced['prec'],
                             rfunc=ps.Exponential,
                             name='prec',
                             up=True)

    sm_et = ps.StressModel(stress=data_stat_sliced['et_ref'],
                           rfunc=ps.Exponential,
                           name='ref_ET',
                           up=False)

    ml.add_stressmodel(sm_prec)
    ml.add_stressmodel(sm_et)

    # calibrate on 2017-2018; 2016 is kept out for validation below
    ml.solve(solver=ps.LeastSquares,
             tmin='2017-01-01',
             tmax='2019-01-01',
             report=True,
             fit_constant=False)

    # simulation over the model's default period (presumably the
    # calibration window set by solve -- confirm against Pastas docs)
    sim = ml.simulate()

    # out-of-sample prediction for 2016
    sim2 = ml.simulate(tmin='2016-01-01',
                       tmax='2017-01-01')

    # --- RMSE over the 2016 validation window ---------------------------
    ix_sim2 = sim2.index
    obs_crop = ob_series.reindex(ix_sim2)
    # reindex (not .loc) so dates missing from the in situ record become NaN
    # instead of raising KeyError; rmse() skips them via pandas' mean
    insitu_crop = data_stat_sliced['sm5'].reindex(ix_sim2)

    rmse_sim.append(rmse(sim2, insitu_crop))
    rmse_obs.append(rmse(obs_crop, insitu_crop))
    rmse_SMAP.append(rmse(sim2, obs_crop))

    # --- plot data -------------------------------------------------------
    prec = ml.get_stress('prec')
    et = ml.get_stress('ref_ET')

    # open plot objects
    fig, ax = plt.subplots(figsize=(6, 4),
                           nrows=3,
                           sharex=True,
                           gridspec_kw={'height_ratios': [5, 2, 2]})

    # plot SMAP observations
    ax[0].scatter(ob_series.index, ob_series.values,
                  label='SMAP',
                  color='#7570b3',
                  s=1,
                  zorder=1)

    # plot pastas simulation over the training period
    ax[0].plot(sim.index, sim.values,
               label='TFN training',
               color='#e6ab02',
               zorder=2)

    # plot pastas prediction for the 2016 validation window
    ax[0].plot(sim2.index, sim2.values,
               label='TFN prediction',
               color='#1b9e77',
               zorder=3)

    # plot in situ 5 cm soil moisture
    ax[0].plot(insitu_crop.index, insitu_crop.values,
               label='In situ',
               color='#d95f02')

    ax[1].bar(prec.index, prec.values)
    ax[2].bar(et.index, et.values)

    # axis limits and ticks: set once, before tight_layout, so the layout is
    # computed for the final tick labels (previously they were overridden
    # after the layout had already been done)
    ax[0].set_xlim(datetime.date(2016, 1, 1), datetime.date(2019, 1, 1))
    ax[0].set_ylim(0, 1)
    ax[1].set_ylim(0, 30)
    ax[2].set_ylim(0, 6)
    ax[1].yaxis.set_ticks([0, 30])
    ax[2].yaxis.set_ticks([0, 6])
    for panel in ax:
        panel.grid()
    ax[0].legend(fancybox=False,
                 framealpha=1,
                 ncol=4,
                 fontsize=9,
                 labelspacing=1)

    # labels (title uses the last two characters of the station name as its number)
    ax[0].set_title('Station ' + str(int(stat_name[-2:])))
    ax[2].set_xlabel('Date [year-month]')
    ax[0].set_ylabel('Volumetric moisture \n content \n [$m^3 m^{-3}$]')
    ax[1].set_ylabel('P \n[mm]')
    ax[2].set_ylabel('ET \n[mm]')
    fig.autofmt_xdate()
    fig.tight_layout()

    setlabel(ax[0], '(A)')
    setlabel(ax[1], '(B)')
    setlabel(ax[2], '(C)')

    # get EVP
    EVP.append(ml.stats.evp())

    if save_fig == 1:
        save_dir = os.path.join('figures', 'paper', '3panels')
        # savefig does not create missing directories
        os.makedirs(save_dir, exist_ok=True)
        save_name = os.path.join(save_dir, 'smap_TFN_' + stat_name)

        for ext in ('.png', '.pdf'):
            fig.savefig(save_name + ext,
                        dpi=300,
                        bbox_inches='tight')

    plt.close(fig)
#%%
###############################################################################
# summary figure: validation RMSE per location
###############################################################################

# in situ RMSE only for the selected stations, shown at 1-based location numbers
rmse_sim_loc = [rmse_sim[i] for i in insit_loc]
ix_plot = [i + 1 for i in insit_loc]

# one row per station processed in the loop above (no hard-coded count)
n_stat = len(rmse_SMAP)
loc_numbers = np.arange(1, n_stat + 1)

fig, ax = plt.subplots(figsize=(5, 6))

ax.scatter(rmse_SMAP, loc_numbers, zorder=93, label="SMAP validation")
ax.scatter(rmse_sim_loc, ix_plot, zorder=95, label="Field validation",
           marker="s")

ax.set_xlim(0, 0.25)
ax.set_ylim(0, n_stat + 1)

ax.set_yticks(loc_numbers)

# dashed guide line at every location
ax.hlines(loc_numbers, 0, 0.25,
          linestyles="dashed",
          linewidths=0.5,
          zorder=2)

# location 1 at the top
ax.invert_yaxis()

ax.legend(fancybox=False,
          framealpha=1,
          loc="lower right")

# raw string: '\ ' is a LaTeX space command, not a Python escape sequence
ax.set_xlabel(r'RMSE [$m^3\ m^{-3}$]')
ax.set_ylabel("Location [-]")

# save alongside the per-station figures when saving is switched on
if save_fig == 1:
    for ext in ("png", "pdf"):
        fig.savefig("TFN_SMAP_insitu_RMSE_2016." + ext,
                    dpi=300,
                    bbox_inches="tight")