import datetime
import os

import cartopy.crs as ccrs
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray  # noqa: F401  (imported for its side effect: registers the .rio accessor)
import xarray as xr
from cartopy.io.shapereader import Reader
from geocube.api.core import make_geocube

from utils import maximize

pd.options.display.max_columns = 500
pd.options.display.width = 0

################
# Main program #
################

# Grid cell sizes to produce, in metres of the RD New grid (EPSG:28992).
spaces = [1000, 5000]
# PROJ definition of RD New; written as the 'proj4' attribute of the NetCDF
# grid-mapping variable alongside the EPSG code.
proj4str = "+proj=sterea +lat_0=52.15616055555555 +lon_0=5.38763888888889 +k=0.9999079 +x_0=155000 +y_0=463000 +ellps=bessel +towgs84=565.4171,50.3319,465.5524,-0.398957388243134,0.343987817378283,-1.87740163998045,4.0725 +units=m +no_defs"

# (features_version, experiment name) pairs to process.
experiments = [
    ("1.4", "landuse-cobra"),
    ("1.4", "no-landuse-cobra"),
    ("1.4", "landuse-no-cobra"),
    ("1.4", "no-landuse-no-cobra"),
]
folder_name = "four_experiments"

sd = datetime.datetime(2022, 2, 18)
ed = datetime.datetime(2022, 2, 19)
# Hourly timestamps from sd up to, but excluding, ed. The "h" alias replaces
# "H", which pandas 2.2 deprecated (older pandas accept "h" as well).
date_range = pd.date_range(start=sd, end=ed, freq="h")[:-1]

# Pre-loading some map elements: the RD New map projection and the extent
# (xmin, ymin, xmax, ymax) used to clip the rail/road shapefiles.
projection = ccrs.epsg(28992)
xmin, xmax, ymin, ymax = [130000, 200000, 357000, 427000]
bbox_full = (xmin, ymin, xmax, ymax)

# Path templates. Every path is built with
# .format(features_version, folder_name, name_exp, file): {1}, {2} and {3}
# are used; {0} receives features_version but these templates do not use it.
path_in_csv = r"../data/predictions/csv/{1}/{2}/{3}"
path_ou_nc = r"../data/predictions/nc/{1}/{2}/{3}"
path_ou_png = r"../data/predictions/png/{1}/{2}/{3}"
path_in_rails = r"../data/geospatial/NATREG_SpoorWegen/railways_RDNew.shp"
path_in_roads = r"../data/geospatial/NWB_wegen/roads_RDNew.shp"

file_template_csv = "{0}m_ensemble_boot/Grid_pred_{0}m_boot_{1}.csv"
file_template_nc = "{0}m_ensemble_boot/Grid_pred_{0}m_boot_{1}.nc"
file_template_png = "{0}m_ensemble_boot/Grid_pred_{0}m_boot_{1}.png"

# The background map layers do not depend on the experiment: read them once.
rail_geometries = list(Reader(path_in_rails, bbox=bbox_full).geometries())
road_geometries = list(Reader(path_in_roads, bbox=bbox_full).geometries())


def _load_ensemble_mean(path_csv):
    """Read one ';'-separated prediction CSV and average its ensemble members.

    The identifying columns (rowid, datetime, longitude, latitude, geometry)
    are moved into the index so that ``mean(axis=1)`` only averages the
    remaining columns. Returns a DataFrame indexed by ``rowid`` whose
    averaged column is named ``mean_pred_ens``.

    Raises:
        FileNotFoundError: if ``path_csv`` does not exist.
    """
    df_cur = pd.read_csv(path_csv, sep=";", header=0, index_col=0).reset_index()
    df_cur = df_cur.set_index(["rowid", "datetime", "longitude", "latitude", "geometry"])  # Hides columns from the mean operator
    return df_cur.mean(axis=1).rename("mean_pred_ens").reset_index().set_index("rowid")


def _grid_ensemble_mean(df_mean, resolution, date):
    """Rasterise the point ensemble means onto a regular RD New grid.

    Args:
        df_mean: output of ``_load_ensemble_mean`` (needs longitude, latitude
            and mean_pred_ens columns).
        resolution: grid cell size in RD New units (metres).
        date: timestamp attached as a length-1 ``time`` dimension.

    Returns:
        Dataset with variable ``oc-ens-avg-pred`` of dims (time, y, x),
        y ascending, CRS written as EPSG:28992.
    """
    # The points are reprojected to RD New (metres) because geocube can then
    # build the output grid from a resolution alone. With lat/lon coordinates
    # the 'resolution' parameter would require knowing the output dims.
    gdf = gpd.GeoDataFrame(df_mean, geometry=gpd.points_from_xy(df_mean.longitude, df_mean.latitude), crs="EPSG:4326")
    gdf_rdnew = gdf.to_crs(epsg=28992)

    # This positions the dataframe results into a 2D array with a particular resolution
    out_grid = make_geocube(vector_data=gdf_rdnew, measurements=["mean_pred_ens"], resolution=(-resolution, resolution))

    # Reuse the coordinates geocube actually rasterised onto, with y flipped to
    # ascending. Rebuilding the axes from the point bounds with np.arange can
    # yield a different length than the grid and is offset from the cell centres.
    grid = out_grid["mean_pred_ens"].sortby("y")
    da = xr.DataArray(grid.values, coords=[grid["y"].values, grid["x"].values], dims=["y", "x"])
    da = da.expand_dims(time=[date])

    da.attrs['standard_name'] = 'tree toppling suitability'
    da.attrs['long_name'] = 'Average classification of ensemble for tree toppling'
    da.attrs['coverage_content_type'] = 'modelResult'
    da.attrs['units'] = '[-1, 1]'

    ds = da.to_dataset(name="oc-ens-avg-pred")
    ds.rio.write_grid_mapping(inplace=True)
    ds.rio.write_crs("epsg:28992", inplace=True)
    ds.spatial_ref.attrs['proj4'] = proj4str
    return ds


def _save_png(ds, path_png, date_str, rails, roads):
    """Plot the gridded ensemble mean over rail/road lines and save it as PNG.

    The figure is always closed afterwards. Otherwise every hour of every
    experiment and resolution would keep a figure alive for the whole run.
    """
    # For a quick look without cartopy:
    #   gdf.plot(column="mean_pred_ens", marker="s", markersize=220, legend=True); plt.show()
    fig, ax = plt.subplots(nrows=1, ncols=1, subplot_kw={'projection': projection})
    try:
        fig.suptitle("Hourly suitability for tree toppling (ensemble mean)", size=24)
        # The data is already in EPSG:28992 (the axes' projection), so no
        # reprojection is needed before plotting.
        field = ds["oc-ens-avg-pred"]
        img = field.plot(ax=ax, alpha=0.7, cmap=plt.cm.RdPu, add_colorbar=False, zorder=0, vmin=-1, vmax=1)

        ax.add_geometries(rails, crs=projection, linewidth=1.00, facecolor='none', edgecolor='#24272C', label="rail", zorder=1)
        ax.add_geometries(roads, crs=projection, linewidth=1.00, facecolor='none', edgecolor='#4E5359', label="road", zorder=3)
        ax.set_title("Time: {0}".format(date_str), size=20, pad=10, fontweight="bold", loc="right")

        cb = fig.colorbar(img, ax=ax)
        cb.set_label(label=field.attrs['long_name'], labelpad=50, size=20)
        cb.ax.tick_params(labelsize=20)

        # Window size affects the saved figure, so maximize before saving.
        maximize()

        print("writing png in: ", path_png)
        fig.savefig(path_png, dpi=150)
    finally:
        plt.close(fig)


for features_version, name_exp in experiments:
    fmt_args = (features_version, folder_name, name_exp)

    for s in spaces:
        for date in date_range:
            print("Processing: ", date)
            date_str = date.strftime("%Y-%m-%d_%H")
            path_cur = path_in_csv.format(*fmt_args, file_template_csv.format(s, date_str))
            print(path_cur)

            # Only a missing *input* file is an expected, skippable condition.
            # Errors while writing outputs must not be swallowed.
            try:
                df_mean = _load_ensemble_mean(path_cur)
            except FileNotFoundError:
                print("Input not found, skipping: ", path_cur)
                continue

            ds = _grid_ensemble_mean(df_mean, s, date)

            # Saving results into raster format
            path_cur_ou_nc = path_ou_nc.format(*fmt_args, file_template_nc.format(s, date_str))
            os.makedirs(os.path.dirname(path_cur_ou_nc), exist_ok=True)
            ds.to_netcdf(path_cur_ou_nc, mode="w", format="NETCDF4", engine="netcdf4", encoding={'time':{'units':'seconds since 1970-01-01 00:00:00'}})

            # Now saving the results into png format
            path_cur_ou_png = path_ou_png.format(*fmt_args, file_template_png.format(s, date_str))
            os.makedirs(os.path.dirname(path_cur_ou_png), exist_ok=True)
            _save_png(ds, path_cur_ou_png, date_str, rail_geometries, road_geometries)