"""Module to produce surface elevation change rates

... from the glacier surface elevation dataset published
with DOI 10.4121/955d7f5a-0e3f-4166-a411-f0dcc4557cb2
"""

from cryoswath.misc import find_region_id, load_glacier_outlines
from cryoswath.l4 import fill_voids, trend_with_seasons
import numpy as np
import xarray as xr
from warnings import catch_warnings, filterwarnings


def _filter_insuficient_data(ds):
    """Mask grid cells whose elevation-change record is too sparse or uncertain.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``elev_diff``, ``elev_diff_obs_count`` and
        ``elev_diff_error`` and have a ``time`` dimension.

    Returns
    -------
    xarray.Dataset
        ``ds`` with its variables masked (set to NaN) wherever any of these
        conditions fails:

        * ``elev_diff_obs_count > 3``,
        * ``elev_diff_error < 30`` (in the units of the input),
        * more than 5 non-NaN ``elev_diff`` values among the first 30 time
          steps *and* more than 5 among the last 30 time steps.
    """
    # Both thresholds are evaluated on the unmasked input.
    has_enough_obs = ds.elev_diff_obs_count > 3
    has_small_error = ds.elev_diff_error < 30
    ds = ds.where(has_enough_obs).where(has_small_error)

    def n_valid(window):
        # Count of non-NaN elev_diff samples per cell inside `window` of the time axis.
        return ds.elev_diff.isel(time=window).notnull().sum("time")

    # Require coverage at both ends of the series (presumably so the fitted
    # trend is anchored at the start and the end of the record).
    covered_at_start = n_valid(slice(None, 30)) > 5
    covered_at_end = n_valid(slice(-30, None)) > 5
    return ds.where(covered_at_start & covered_at_end)


def _fit_trend_with_seasons(da):
    """Fit a linear trend plus yearly and semi-yearly harmonics along ``time``.

    Parameters
    ----------
    da : xarray.DataArray
        Data with ``time``, ``y`` and ``x`` dimensions.

    Returns
    -------
    xarray.Dataset
        Result of ``DataArray.curvefit`` with ``trend_with_seasons`` as model,
        providing ``curvefit_coefficients`` (dimension ``param``) and
        ``curvefit_covariance`` (dimensions ``cov_i``/``cov_j``). Fit failures
        do not raise (``errors="ignore"``).
    """
    # Amplitudes must be non-negative and phases are restricted to [-pi, pi]
    # so that each harmonic has a unique parametrisation.
    seasonal_bounds = {
        "amp_yearly": (0, np.inf),
        "phase_yearly": (-np.pi, np.pi),
        "amp_semiyr": (0, np.inf),
        "phase_semiyr": (-np.pi, np.pi),
    }
    # Order must match the positional parameters of `trend_with_seasons`.
    parameter_names = ["trend", "offset", "amp_yearly", "phase_yearly", "amp_semiyr", "phase_semiyr"]
    time_first = da.transpose("time", "y", "x")
    return time_first.curvefit(
        coords="time",
        func=trend_with_seasons,
        param_names=parameter_names,
        bounds=seasonal_bounds,
        errors="ignore",
    )


def _calculate_residuals(fit_params, da):
    """Return the deviation of ``da`` from the fitted trend-plus-seasons model.

    Parameters
    ----------
    fit_params : xarray.DataArray
        Fit coefficients with a ``param`` dimension, e.g. the
        ``curvefit_coefficients`` returned by ``_fit_trend_with_seasons``.
        The entries along ``param`` are passed positionally to
        ``trend_with_seasons``, so their order must match its signature.
    da : xarray.DataArray
        The fitted data; must have a ``time`` coordinate.

    Returns
    -------
    xarray.DataArray
        ``da`` minus the model evaluated at ``da.time``, carrying the name of
        ``da``.
    """
    # Cast explicitly to 64 bit: with NumPy < 2, ``astype("int")`` resolves to
    # the platform C long, which is 32 bit on Windows and overflows for
    # datetime64[ns] timestamps. "int64" yields identical values on 64-bit
    # Linux/macOS and correct values on Windows.
    # NOTE(review): assumes the fit saw the same integer representation of
    # ``time`` — confirm against `trend_with_seasons`.
    time_numeric = da.time.astype("int64")
    # Positional coefficients, in the order of the ``param`` coordinate.
    coefficients = [fit_params.sel(param=p) for p in fit_params.param]
    model_vals = xr.apply_ufunc(trend_with_seasons, time_numeric, *coefficients, dask="allowed")
    # Arithmetic between differently named DataArrays drops the name; naming
    # the model like the data keeps that name on the residuals.
    return da - model_vals.rename(da.name)


def _filter_uncertain_trends(fit_results):
    """Mask fits whose trend or seasonal amplitudes are poorly constrained.

    Parameters
    ----------
    fit_results : xarray.Dataset
        Output of ``_fit_trend_with_seasons``; its ``curvefit_covariance``
        is indexed by ``cov_i`` and ``cov_j``.

    Returns
    -------
    xarray.Dataset
        ``fit_results`` masked wherever the trend variance is not below 2 or
        either amplitude variance is not below 100. NaN variances fail the
        comparisons and are masked as well.
    """
    covariance = fit_results.curvefit_covariance

    def variance(name):
        # Diagonal element of the covariance matrix for parameter `name`.
        return covariance.sel(cov_i=name, cov_j=name)

    # trend variance < 2 (m/yr)^2
    trend_is_certain = variance("trend") < 2
    # yearly and semi-yearly amplitude standard deviation < 10 m (variance < 100)
    amplitudes_are_certain = (variance("amp_yearly") < 100) & (variance("amp_semiyr") < 100)
    return fit_results.where(trend_is_certain & amplitudes_are_certain)


def build_trend_georaster(ds):
    """Compute a void-filled raster of surface elevation change rates.

    Processing steps:

    1. Rechunk (whole time axis per chunk, 100 x 100 cells in x/y) and mask
       cells with insufficient data via ``_filter_insuficient_data``.
    2. Fit trend plus yearly/semi-yearly harmonics, then flag samples whose
       absolute residual exceeds twice the per-cell residual standard
       deviation along time.
    3. Refit without the flagged samples and mask poorly constrained fits via
       ``_filter_uncertain_trends``.
    4. Fill the remaining voids with ``fill_voids``, using the glacier basin
       outlines of the region returned by ``find_region_id``.

    Parameters
    ----------
    ds : xarray.Dataset
        Gridded input with ``time``, ``y`` and ``x`` dimensions and the
        variables ``elev_diff``, ``elev_diff_obs_count`` and
        ``elev_diff_error``. Its CRS must be readable via ``ds.rio.crs``
        (assumes the rioxarray ``.rio`` accessor is registered, presumably by
        an imported cryoswath module — TODO confirm).

    Returns
    -------
    xarray.Dataset
        Variables ``trend`` and ``trend_std`` (square root of the fitted
        trend variance) with dimensions ``("y", "x")`` and the
        ``_FillValue`` attribute set to NaN.
    """
    # Keep the complete time axis inside each dask chunk (the fit runs along
    # time) and drop cells with too little data.
    ds = _filter_insuficient_data(ds.chunk(dict(time=-1, x=100, y=100)))
    elev_diff = ds.elev_diff

    # First pass: fit, then flag samples far away from the fitted model.
    first_fit = _fit_trend_with_seasons(elev_diff)
    residuals = _calculate_residuals(first_fit.curvefit_coefficients, elev_diff)
    is_outlier = np.abs(residuals) > 2 * residuals.std("time")

    # Second pass without outliers; keep only the trend coefficient and its variance.
    second_fit = _filter_uncertain_trends(_fit_trend_with_seasons(elev_diff.where(~is_outlier)))
    trend_fit = second_fit.sel(param="trend", cov_i="trend", cov_j="trend").drop_vars(
        ["param", "cov_i", "cov_j"]
    )

    basins = load_glacier_outlines(find_region_id(ds), "basins", False)
    unfilled = (
        xr.Dataset()
        .assign(
            trend=trend_fit.curvefit_coefficients,
            trend_std=trend_fit.curvefit_covariance**0.5,
        )
        .rio.write_crs(ds.rio.crs)
        .compute()
    )
    filled = fill_voids(
        unfilled,
        "trend",
        "trend_std",
        basin_shapes=basins,
        outlier_replace=True,
        outlier_limit=2,
    )
    result = filled[["trend", "trend_std"]].transpose("y", "x")
    # Declare NaN as the fill value (presumably used as nodata when the
    # raster is written with ``rio.to_raster``).
    for var_name in ("trend", "trend_std"):
        result[var_name].attrs["_FillValue"] = np.nan
    return result


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Calculates surface elevation change rate")
    parser.add_argument("input_path", nargs=1)
    parser.add_argument("output_path", nargs=1)

    with catch_warnings():
        filterwarnings("ignore", "invalid value encountered in divide", RuntimeWarning)
        raster = build_trend_georaster(xr.open_dataset(parser.parse_args().input_path[0]))
    raster.rio.to_raster(parser.parse_args().output_path[0])
