"""
Tide datasets: gridded velocity time series and tidal-ellipse representations.
"""

from __future__ import print_function, division
import numpy as np
import pandas as pd
import datetime
from scipy.interpolate import RectBivariateSpline
from ..tools import ttide, transport
import os


def magnitude(a, b):
    """Return the element-wise Euclidean norm ``sqrt(a**2 + b**2)``."""
    sum_of_squares = a**2 + b**2
    return sum_of_squares**.5


def RMS(diff):
    """Return the root-mean-square of all elements of ``diff``.

    ``diff`` may be any array-like of any shape.  Masked entries of a
    ``numpy.ma.MaskedArray`` are excluded from the mean.
    """
    # asanyarray keeps MaskedArray inputs (and their masks) intact; np.mean
    # then averages over every element instead of dividing by len(), which
    # only counts the first axis.
    values = np.asanyarray(diff)
    return np.sqrt(np.mean(np.square(values)))


class BaseTide(object):
    """Gridded time series of horizontal velocity components.

    ``U`` and ``V`` must both have shape ``(t.size, y.size, x.size)``, i.e.
    (time, y, x).  The date helpers (``startDT``, ``endDT``, ``timespan``)
    interpret ``t`` as minutes since 1970-01-01.
    """

    _EPOCH = datetime.datetime(year=1970, month=1, day=1)

    def __init__(self, x, y, t, U, V):
        self.x = x
        self.y = y
        self.t = t
        self.U = U
        self.V = V

        expected = (self.t.size, self.y.size, self.x.size)
        if self.U.shape != expected or self.V.shape != expected:
            raise ValueError('input data of invalid shape, expected ({}, {}, {})'.format(*expected))

    @property
    def complex(self):
        """Velocity as a complex array ``U + 1j*V``."""
        return self.U+1j*self.V

    @property
    def shape(self):
        """Shape of the velocity arrays, ``(time, y, x)``."""
        return self.U.shape

    def mean(self):
        """Return ``(x, y, U_mean, V_mean)`` with velocities averaged over time."""
        return self.x, self.y, self.U.mean(axis=0), self.V.mean(axis=0)

    @property
    def points(self):
        """Return grid positions as two flat arrays ``(X, Y)`` in row-major (y, x) order."""
        X, Y = np.meshgrid(self.x, self.y)
        return X.reshape(-1), Y.reshape(-1)

    def save(self, filename):
        """Write the dataset to an ``.npz`` archive; masked values are stored as NaN.

        :raises ValueError: if ``filename`` does not have the ``.npz`` extension
        """
        if os.path.splitext(filename)[1] != '.npz':
            raise ValueError('filename must have .npz extension')

        U = self.U
        if isinstance(U, np.ma.MaskedArray):
            U = U.filled(np.nan)

        V = self.V
        if isinstance(V, np.ma.MaskedArray):
            V = V.filled(np.nan)

        np.savez(filename, x=self.x, y=self.y, t=self.t, u=U, v=V)

    @classmethod
    def load(cls, filename, **kwargs):
        """Load a dataset from a ``.nc`` file or an ``.npz`` archive written by ``save``.

        NaN/inf velocities from an ``.npz`` archive are masked.  Extra keyword
        arguments are forwarded to the constructor in both cases (e.g. ``name``
        for ``TidalComponent``).
        """
        if os.path.splitext(filename)[1] == '.nc':
            from .matroos import VelocitiesData
            data = VelocitiesData.load(filename=filename)
            return cls.from_dataset(data, **kwargs)

        # context manager releases the archive's file handle once the arrays are read
        with np.load(filename) as npz:
            x, y, t = npz['x'], npz['y'], npz['t']
            U = np.ma.masked_invalid(npz['u'])
            V = np.ma.masked_invalid(npz['v'])
        return cls(x, y, t, U, V, **kwargs)

    def plot_points(self, ax, posscale=.001, **kwargs):
        """Scatter the grid positions (scaled by ``posscale``) on ``ax``."""
        x, y = self.points
        return ax.scatter(posscale*x, posscale*y, lw=0, **kwargs)

    def plot(self, ax, vscale=5, posstep=5, posscale=.001, color='k', tmax=None, center_color='k', center_dict=None, **kwargs):
        """Draw velocity hodograph traces at every ``posstep``-th grid point.

        Each trace is ``vscale*velocity`` offset by the scaled grid position.
        Only timesteps with ``t < tmax`` are drawn (all when ``tmax`` is None).
        ``center_color`` and ``center_dict`` are accepted but currently unused.
        """
        # create flattened position arrays
        X, Y = np.meshgrid(self.x[::posstep], self.y[::posstep])
        x = X.reshape(-1)
        y = Y.reshape(-1)

        # keep every timestep strictly before tmax (including the last one)
        if tmax is None:
            sel = slice(None)
        else:
            sel = np.asarray(self.t) < tmax

        U = self.U[sel, ::posstep, ::posstep]
        V = self.V[sel, ::posstep, ::posstep]

        # flatten to (time, point) so each column is one trace
        U = U.reshape(U.shape[0], -1)
        V = V.reshape(V.shape[0], -1)

        return ax.plot(vscale*U+posscale*x, vscale*V+posscale*y, color=color, **kwargs)

    @classmethod
    def from_dataset(cls, data, **kwargs):
        """Build an instance from an object exposing ``x, y, t, u, v`` attributes.

        Extra keyword arguments are forwarded to the constructor.
        """
        return cls(data.x, data.y, data.t, data.u, data.v, **kwargs)

    def __add__(self, other):
        """Superpose two velocity fields defined on identical times and grids.

        :raises ValueError: if times or grid coordinates differ
        """
        if not isinstance(other, BaseTide):
            # let Python try the reflected operation / raise TypeError
            return NotImplemented
        if not np.array_equal(self.t, other.t):
            raise ValueError('could not match timeseries')
        if not (np.array_equal(self.x, other.x) and np.array_equal(self.y, other.y)):
            raise ValueError('could not match locations')
        return BaseTide(self.x, self.y, self.t, self.U+other.U, self.V+other.V)

    def __repr__(self):
        return '<{cls} object {startDT}+{timespan.days}days>'.format(
                cls=self.__class__.__name__,
                startDT=self.startDT.strftime('%Y-%m-%d'),
                timespan=self.timespan)

    @staticmethod
    def _minutes_to_timedelta(minutes):
        # timedelta rejects numpy integer scalars, so coerce to a Python float
        return datetime.timedelta(seconds=float(minutes*60))

    @property
    def timespan(self):
        """Duration between the first and last timestep."""
        return self._minutes_to_timedelta(self.t[-1] - self.t[0])

    @property
    def startDT(self):
        """Datetime of the first timestep."""
        return self._EPOCH + self._minutes_to_timedelta(self.t[0])

    @property
    def endDT(self):
        """Datetime of the last timestep."""
        return self._EPOCH + self._minutes_to_timedelta(self.t[-1])

    @property
    def magnitude(self):
        """Speed ``sqrt(U**2 + V**2)`` at every time and grid point."""
        return magnitude(self.U, self.V)

    def interp(self, px, py, key=None):
        """Interpolate a gridded quantity at points ``(px, py)``.

        :param key: callable taking this object and returning a 2-D array on
            the (y, x) grid, e.g. ``lambda tide: tide.max_velocity``
        """
        if key is None:
            raise TypeError('key must be a callable returning a 2-D (y, x) array')
        return RectBivariateSpline(self.y, self.x, key(self)).ev(py, px)

    @property
    def elongation(self):
        """Ratio of peak to minimum speed, mapped as ``1 - 1/(ratio + 1)``."""
        frac = self.max_velocity / self.min_velocity
        return 1 - 1 / (frac + 1)  # scale between 0 and 1

    @property
    def tidal_axis(self, component='M2'):
        """Direction (radians, via arctan of V/U) of the velocity at peak speed.

        ``component`` is unused; as a property it can never be supplied.
        """
        I0 = np.argmax(self.magnitude, axis=0)
        I1, I2 = np.indices((self.V.shape[1:]))
        return np.arctan(self.V[I0, I1, I2]/self.U[I0, I1, I2])

    @property
    def max_velocity(self):
        """Peak speed over time per grid point; masked where zero or masked."""
        m = np.amax(self.magnitude, axis=0)
        if isinstance(m, np.ma.MaskedArray):
            m = m.filled(0)
        return np.ma.masked_equal(m, 0)

    @property
    def min_velocity(self):
        """Minimum speed over time per grid point; masked where zero or masked."""
        m = np.amin(self.magnitude, axis=0)
        if isinstance(m, np.ma.MaskedArray):
            m = m.filled(0)
        return np.ma.masked_equal(m, 0)

    def transport_approx(self, depth=30, d50=300e-6, vanRijn=False):
        """Approximate bed-load transport from the complex velocity.

        Uses ``transport.bed_load_vRijn`` when ``vanRijn`` is true, otherwise
        ``transport.bed_load_transport``.  (``d50`` presumably in metres — confirm.)
        """
        if vanRijn:
            fn = transport.bed_load_vRijn
        else:
            fn = transport.bed_load_transport
        return fn(self.complex, depth, d50)

    @property
    def mean_velocity(self):
        """Time-averaged velocity as a complex (y, x) array."""
        return self.U.mean(axis=0) + 1j*self.V.mean(axis=0)


class TidalComponent(BaseTide):
    """Velocity field of a single, optionally named, tidal component.

    Takes the positional arguments of ``BaseTide`` plus an optional ``name``
    keyword; any other keyword arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        super(TidalComponent, self).__init__(*args)
        self.name = kwargs['name'] if 'name' in kwargs else None

    @classmethod
    def load(cls, filename, name):
        """Load a component via ``BaseTide.load``, forwarding ``name`` to it."""
        parent = super(TidalComponent, cls)
        return parent.load(filename, name=name)

    def __repr__(self):
        template = '<{cls} {name} {startDT}+{timespan.days}days>'
        return template.format(cls=type(self).__name__,
                               name=self.name,
                               startDT=self.startDT.strftime('%Y-%m-%d'),
                               timespan=self.timespan)


class Tide(BaseTide):
    """Velocity field on a uniform time axis, with harmonic-analysis helpers.

    Irregularly sampled input is padded onto a uniform time axis on
    construction (see ``check_timeseries``).  The reduction methods delegate
    the harmonic analysis to ``ttide.reduce``.
    """

    # Reference periods selectable through ``period``, in the units of ``dt``
    # (hours, given the 14.77*24 factor): 'sn' is 14.77 days (presumably the
    # spring-neap cycle) and 'M2' is 12.42 h.
    _PERIODS = {'sn': 14.77*24, 'M2': 12.42}

    def __init__(self, x, y, t, U, V):
        t, U, V = self.check_timeseries(t, U, V)
        super(Tide, self).__init__(x, y, t, U, V)

    def check_timeseries(self, t, U, V):
        """Return ``(t, U, V)`` on a uniformly spaced time axis.

        The spacing is the smallest difference between consecutive times.  If
        ``t`` is already uniform (or has fewer than two samples) the inputs are
        returned unchanged; otherwise missing timesteps are inserted as fully
        masked slices.  Samples whose times do not fall exactly on the new axis
        are dropped.

        :raises ValueError: if ``t`` is irregular and not strictly increasing
        """
        t = np.asarray(t)
        if t.size < 2:
            return t, U, V

        tdiff = np.diff(t)
        dt = np.amin(tdiff)
        if (tdiff == dt).all():
            return t, U, V
        if dt <= 0:
            raise ValueError('time axis must be strictly increasing')

        t_start = np.amin(t)
        # count steps explicitly so the final timestep is kept
        # (np.arange(start, stop, dt) excludes ``stop``)
        nsteps = int(np.round((np.amax(t) - t_start) / dt)) + 1
        t_new = t_start + dt * np.arange(nsteps)

        old_mask = np.isin(t, t_new)
        new_mask = np.isin(t_new, t)

        _, ny, nx = np.shape(U)
        U_new = np.ma.masked_invalid(np.full((t_new.size, ny, nx), np.nan))
        V_new = np.ma.masked_invalid(np.full((t_new.size, ny, nx), np.nan))

        U_new[new_mask, :, :] = U[old_mask, :, :]
        V_new[new_mask, :, :] = V[old_mask, :, :]

        return t_new, U_new, V_new

    def _calculate_time_index(self, constitnames, period=None, dt=.5):
        """Return the number of ``dt`` steps that span one reference period.

        ``period`` is 'sn' or 'M2'; when None it is 'sn' if 'S2' is among
        ``constitnames`` and 'M2' otherwise.

        :raises ValueError: for any other ``period``
        """
        if period is None:
            period = 'sn' if 'S2' in constitnames else 'M2'
        try:
            T = self._PERIODS[period]
        except KeyError:
            raise ValueError('invalid period {!r}: expected one of {}'.format(
                period, sorted(self._PERIODS)))
        return int(T / dt)

    def reduce_point(self, t, vel, dt=.5, lat=52.5, constitnames=('M2', 'M4', 'S2'), period=None):
        """Harmonically reconstruct one complex velocity series.

        :param period: step count (int) or a period name for ``_calculate_time_index``
        :return: ``(t, reconstruction)`` with the reconstruction truncated to
            one reference period
        """
        if isinstance(period, int):
            ti_max = period
        else:
            ti_max = self._calculate_time_index(constitnames, period, dt)

        _, _, _, TIDE = ttide.reduce(vel, dt=dt, lat=lat, constitnames=constitnames, output=False)
        return t, TIDE[:ti_max].reshape(-1)

    def reduce(self, dt=.5, constitnames=('M0', 'M2', 'M4', 'S2'), period=None, **kwargs):
        """Return a new ``Tide`` holding one period of the reconstructed tide.

        'M0' adds the time-mean velocity; the remaining constituents are fitted
        per grid point with ``reduce_point``.  Grid points whose input is fully
        masked are masked in the result.  ``kwargs`` go to ``reduce_point``.
        """
        ti_max = self._calculate_time_index(constitnames, period, dt)

        # (time, y, x), matching the layout BaseTide validates
        shape = (ti_max, self.y.size, self.x.size)

        t = self.t[:ti_max]
        # masked arrays so that empty grid points stay masked in the result
        U = np.ma.zeros(shape, dtype=float)
        V = np.ma.zeros(shape, dtype=float)

        if 'M0' in constitnames:
            constitnames = [c for c in constitnames if c != 'M0']
            vel_M0 = self.mean_velocity
            U += vel_M0.real
            V += vel_M0.imag

        if constitnames:
            for i in range(self.y.size):
                for j in range(self.x.size):
                    veldata = self.U[:, i, j] + 1j*self.V[:, i, j]
                    # getmaskarray also works when the input is a plain ndarray
                    if np.ma.getmaskarray(veldata).all():
                        U[:, i, j] = np.ma.masked
                        V[:, i, j] = np.ma.masked
                    else:
                        _, vel = self.reduce_point(self.t, veldata, dt=dt, constitnames=constitnames, period=ti_max, **kwargs)
                        U[:, i, j] += vel.real
                        V[:, i, j] += vel.imag

        return Tide(self.x, self.y, t, U, V)

    def reduce_point_as_ellipse(self, vel, dt=.5, lat=52.5, constitnames=('M2', 'M4', 'S2')):
        """Fit one complex velocity series with ``ttide.reduce``.

        :return: ``(freq, Amaj, Amin, inclination, phase)`` taken from the
            ttide frequency output and columns 0, 2, 4 and 6 of its properties
        """
        _, freq, props, _ = ttide.reduce(vel, dt=dt, lat=lat, constitnames=constitnames, output=False)
        return freq, props[:, 0], props[:, 2], props[:, 4], props[:, 6]

    def reduce_as_ellipses(self, constitnames=('M0', 'M2', 'M4', 'S2'), dt=.5, lat=52.5):
        """Fit tidal ellipses at every grid point and return a ``TidalEllipsesGrid``.

        An 'M0' entry stores the time-mean velocity as its (complex) major
        axis with zero frequency, inclination and phase.  Fully masked grid
        points get NaN parameters.
        """
        shape = len(constitnames), self.y.size, self.x.size

        names = constitnames
        freq = np.zeros(shape, dtype=float)
        Amaj = np.zeros(shape, dtype=complex)
        Amin = np.zeros(shape, dtype=complex)
        incl = np.zeros(shape, dtype=float)
        pha = np.zeros(shape, dtype=float)

        # index of the first harmonic constituent (after an optional M0 row)
        i = 0

        if 'M0' in constitnames:
            constitnames = [c for c in constitnames if c != 'M0']
            Amaj[i, :, :] = self.mean_velocity
            i += 1

        if constitnames:
            for ii in range(self.y.size):
                for jj in range(self.x.size):
                    veldata = self.U[:, ii, jj] + 1j*self.V[:, ii, jj]
                    if np.ma.getmaskarray(veldata).all():
                        _freq, _Amaj, _Amin, _incl, _pha = np.nan, np.nan, np.nan, np.nan, np.nan
                    else:
                        _freq, _Amaj, _Amin, _incl, _pha = self.reduce_point_as_ellipse(veldata, dt=dt, lat=lat, constitnames=constitnames)
                    freq[i:, ii, jj] = _freq
                    Amaj[i:, ii, jj] = _Amaj
                    Amin[i:, ii, jj] = _Amin
                    incl[i:, ii, jj] = _incl
                    pha[i:, ii, jj] = _pha

        return TidalEllipsesGrid(self.x, self.y, names, freq, Amaj, Amin, incl, pha)


"""

Ellipse-based representation of tidal constituents.

"""
def ellipse(t, Amaj, Amin, angvel, phase=0, incl=0):
    """Return the complex velocity of a tidal ellipse at time(s) ``t``.

    The ellipse is written as the sum of two counter-rotating circular
    components with radii ``(Amaj + Amin)/2`` and ``(Amaj - Amin)/2``.
    ``angvel*t - phase`` is the rotation angle; ``incl`` rotates the ellipse.
    """
    theta = angvel*t - phase
    # exp(+i*theta) turns counter-clockwise for positive angvel, exp(-i*theta) clockwise
    ccw_part = .5*(Amaj + Amin)*np.exp(1j*(theta + incl))
    cw_part = .5*(Amaj - Amin)*np.exp(-1j*(theta - incl))
    return ccw_part + cw_part


def from_ttide(names, freq, props, dt=.5):
    """Build a ``TidalEllipseSet`` from ttide output and sample it in time.

    :param props: ttide properties; columns 0, 2, 4 and 6 are read as major
        axis, minor axis, inclination and phase (degrees)
    :param dt: sampling interval passed to ``as_timeseries``
    :return: ``(t, vel)`` as returned by ``as_timeseries``
    """
    return TidalEllipseSet.from_ttide(names, freq, props).as_timeseries(dt=dt)


class Ellipse(object):
    """Parameters of one or more tidal ellipses.

    All parameters are stored as masked arrays with NaN/inf entries masked.
    ``inclination`` and ``phase`` are stored in radians; they are converted
    from degrees on input unless ``angles='radian'``.  ``from_timeseries``
    multiplies ``freq`` by ``2*pi*t`` with ``t`` in hours, so ``freq`` is
    treated as cycles per hour.
    """

    def __init__(self, freq, Amaj, Amin, inclination, phase, angles='degree'):
        self.freq = np.ma.masked_invalid(freq)
        self.Amajor = np.ma.masked_invalid(Amaj)
        self.Aminor = np.ma.masked_invalid(Amin)

        if angles == 'degree':
            inclination = inclination*np.pi/180
            phase = phase*np.pi/180
        elif angles != 'radian':
            raise ValueError('invalid value for angles: "degree" | "radian" expected')

        self.inclination = np.ma.masked_invalid(inclination)
        self.phase = np.ma.masked_invalid(phase)

    @property
    def T(self):
        """Longest period ``1/freq`` among entries with positive frequency.

        Non-positive frequencies (e.g. a zero-frequency mean-flow entry) have
        no finite period and are skipped; returns ``inf`` if none remain.
        """
        positive = np.ma.masked_less_equal(self.freq, 0)
        fmin = positive.min()
        if fmin is np.ma.masked:
            return np.inf
        return 1/fmin

    def as_timeseries(self, dt=.5, tmin=0, tspan=None):
        """
        return tidal ellipse as velocity timeseries with times in hrs
        :param dt: timestep (hrs)
        :param tmin: start time (hrs)
        :param tspan: duration (hrs); defaults to ``self.T``. Times run from
            ``tmin`` up to, but excluding, ``tmin + tspan``
        :return: timeseries, complex velocities
        """
        if tspan is None:
            tspan = self.T
        t = np.arange(tmin, tmin+tspan, dt)

        vel = self.from_timeseries(t)
        return t, vel

    def from_timeseries(self, t):
        """Evaluate every ellipse at times ``t`` (hrs).

        :return: complex velocities of shape ``(t.size,) + self.shape``
            (``(t.size, 1)`` for scalar parameters)
        """
        t = np.asarray(t)
        # shape time as (N, 1, ..., 1) so it broadcasts against the (1, *shape)
        # parameters for any parameter rank (e.g. (constituent, y, x) grids)
        ndim = max(len(self.shape), 1)
        t_col = t.reshape((-1,) + (1,) * ndim)
        lead = (1,) + tuple(self.shape)
        return ellipse(2*np.pi * t_col,
                       self.Amajor.reshape(lead),
                       self.Aminor.reshape(lead),
                       self.freq.reshape(lead),
                       self.phase.reshape(lead),
                       self.inclination.reshape(lead))

    def as_dataframe(self):
        """Return the parameters as a ``pandas.DataFrame``, one row per ellipse.

        The ``name`` column comes from a ``name`` attribute if present,
        otherwise from ``names``.  Only works for 0-d or 1-d parameters.
        """
        label = getattr(self, 'name', None)
        if label is None:
            label = getattr(self, 'names', None)

        data = dict(name=label if label is None else np.atleast_1d(label))
        for key, value in (('freq', self.freq),
                           ('Amajor', self.Amajor),
                           ('Aminor', self.Aminor),
                           ('inclination', self.inclination),
                           ('phase', self.phase)):
            # atleast_1d lets a single (0-d) ellipse become a one-row frame
            data[key] = np.atleast_1d(value)
        return pd.DataFrame(data)

    @property
    def shape(self):
        """Shape of the parameter arrays."""
        return self.freq.shape

    @property
    def elongation(self):
        """Axis ratio mapped as ``1 - 1/(Amajor/Aminor + 1)``."""
        return 1 - 1 / (self.Amajor / self.Aminor + 1)  # scale between 0 and 1

    def __str__(self):
        return str(self.as_dataframe())


class TidalEllipse(Ellipse):
    """A single named tidal ellipse (one constituent).

    Arguments after ``name`` are those of ``Ellipse``.  Unlike ``Ellipse``,
    angles default to radians: ``TidalEllipseSet.__getitem__`` builds these
    objects from parameters that ``Ellipse.__init__`` has already converted
    to radians.  Pass ``angles='degree'`` explicitly for degree input.
    """

    def __init__(self, name, *args, **kwargs):
        self.name = name
        # args[5] would be an explicit positional ``angles``; only default it otherwise
        if len(args) < 6:
            kwargs.setdefault('angles', 'radian')
        super(TidalEllipse, self).__init__(*args, **kwargs)


class TidalEllipseSet(Ellipse):
    """Named set of tidal ellipses, one entry per constituent in ``names``."""

    def __init__(self, names, freq, Amaj, Amin, inclination, phase, **kwargs):
        size = len(names)
        self.names = names
        for v in (freq, Amaj, Amin, inclination, phase):
            # np.size also accepts plain sequences, unlike the ``.size`` attribute
            if np.size(v) != size:
                raise ValueError('input not of equal size')

        super(TidalEllipseSet, self).__init__(freq, Amaj, Amin, inclination, phase, **kwargs)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, item):
        """
        get an ellipse by name or index
        :param item: name (str) or index (int)
        :return: TidalEllipse instance
        :raises KeyError: unknown name
        :raises IndexError: index out of range
        """
        if isinstance(item, str):
            try:
                # list.index signals a missing value with ValueError
                item = list(self.names).index(item)
            except ValueError:
                raise KeyError('{} not in {}'.format(item, self.names))

        try:
            name = self.names[item]
            params = (self.freq[item],
                      self.Amajor[item],
                      self.Aminor[item],
                      self.inclination[item],
                      self.phase[item])
        except IndexError:
            raise IndexError('index {} out of range for {} of length {}'.format(
                    item, self.__class__.__name__, len(self)))

        # stored angles are already radians; avoid a second degree conversion
        return TidalEllipse(name, *params, angles='radian')

    @classmethod
    def from_ttide(cls, names, freq, props):
        """Build a set from ttide output (props columns 0, 2, 4, 6; angles in degrees)."""
        return cls(names,
                   freq,
                   Amaj=props[:, 0],
                   Amin=props[:, 2],
                   inclination=props[:, 4],
                   phase=props[:, 6],
                   angles='degree')

    def as_total_timeseries(self, *args, **kwargs):
        """Return ``(t, vel)`` with the velocities of all constituents summed."""
        t, vel = self.as_timeseries(*args, **kwargs)
        return t, np.sum(vel, axis=1)


class TidalEllipsesGrid(Ellipse):
    """Tidal ellipses for several constituents on a regular (y, x) grid.

    Every parameter array has shape ``(len(names), y.size, x.size)``.
    """

    def __init__(self, x, y, names, freq, Amaj, Amin, inclination, phase, **kwargs):
        shape = len(names), y.size, x.size

        for n, v in zip(("freq", "Amaj", "Amin", "inclination", "phase"),
                        (freq, Amaj, Amin, inclination, phase)):
            if v.shape != shape:
                raise ValueError('input {} not of expected shape {}'.format(n, shape))

        self.x = x
        self.y = y
        self.names = names

        super(TidalEllipsesGrid, self).__init__(freq, Amaj, Amin, inclination, phase, **kwargs)

    def as_tide_object(self, **kwargs):
        """Superpose all constituents into a ``Tide`` velocity field.

        Keyword arguments are forwarded to ``as_timeseries`` (``dt``, ``tmin``,
        ``tspan``).  NOTE(review): ``as_timeseries`` yields times in hours
        while ``BaseTide``'s date helpers treat ``t`` as minutes — confirm.
        """
        t, vel = self.as_timeseries(**kwargs)
        # vel is expected as (time, constituent, y, x); sum over constituents
        total = np.sum(vel, axis=1)
        return Tide(self.x, self.y, t, total.real, total.imag)

    def __str__(self):
        return '<{} object for {}>'.format(self.__class__.__name__, self.names)
