import numpy as np
from collections import namedtuple


def transect_peak_uncertainty(x, y, mask, peaks=True):
    """Estimate the true position of sampled peaks (or troughs) on one transect.

    For every flagged sample ``i`` the higher of its two neighbours is chosen
    (the left one on ties) to form a pair ``(p0, p1)`` with ``p0 < p1``.
    A straight line is extended through the segment left of the pair
    (``p0 - 1 -> p0``) and another through the segment right of it
    (``p1 -> p1 + 1``).  Their intersection ``(xi, yi)`` is the estimated
    extremum, and its offset from the flagged sample is the uncertainty.
    Each slope uses its own segment spacing, so ``x`` need not be uniform.

    Parameters
    ----------
    x : array_like, 1-D
        Sample positions along the transect.
    y : array_like, 1-D
        Sample values, same shape as ``x``.  Masked entries of masked-array
        inputs (``x`` or ``y``) are treated as NaN.
    mask : array_like, 1-D
        Same shape as ``x``; True/non-zero entries mark the samples to analyse.
    peaks : bool, optional
        If False, troughs are analysed.  ``y`` is mirrored before the
        computation and the vertical results are mirrored back.

    Returns
    -------
    xi, yi, err_x, err_y : ndarray
        One value per flagged sample, in index order: the estimated extremum
        position and value, and their offsets from the flagged sample.
        Entries are NaN where no estimate is possible: the four-point stencil
        runs off either end of the transect, the lines are parallel or involve
        non-finite data, or the intersection lies below the sampled peak
        (above the sampled trough when ``peaks=False``).
    """
    x = np.ma.filled(np.ma.asarray(x, dtype=float), np.nan)
    y = np.ma.filled(np.ma.asarray(y, dtype=float), np.nan)
    mask = np.asarray(mask)
    assert x.shape == y.shape == mask.shape, (x.shape, y.shape, mask.shape)

    # invert y axis if troughs are analysed, so they can be treated as peaks
    pt_mod = 1 if peaks else -1
    y = pt_mod * y

    n = y.size
    last = max(n - 1, 0)
    pi = np.flatnonzero(mask)

    # Candidate neighbours, clipped so end samples never wrap to index -1 or
    # overflow to index n; those samples are flagged invalid further down.
    alts = np.clip(np.array([pi - 1, pi + 1]), 0, last)
    m = np.argmax(y[alts], axis=0)
    alt_pi = alts[m, np.arange(m.size)]

    p0 = np.minimum(pi, alt_pi)
    p1 = np.maximum(pi, alt_pi)

    # The extrapolation needs the outer samples p0 - 1 and p1 + 1 to exist.
    valid = (p0 >= 1) & (p1 <= n - 2)
    lo = np.clip(p0 - 1, 0, last)
    hi = np.clip(p1 + 1, 0, last)

    # Degenerate stencils (invalid ends, parallel lines) may divide by zero;
    # the resulting non-finite values are turned into NaN below.
    with np.errstate(divide='ignore', invalid='ignore'):
        s0 = (y[p0] - y[lo]) / (x[p0] - x[lo])
        s1 = (y[hi] - y[p1]) / (x[hi] - x[p1])

        # intersection of y = y[p0] + s0*(x - x[p0]) and y = y[p1] + s1*(x - x[p1])
        xi = (y[p1] - y[p0] + s0 * x[p0] - s1 * x[p1]) / (s0 - s1)
        yi = s0 * (xi - x[p0]) + y[p0]

        err_x = xi - x[pi]
        err_y = yi - y[pi]

        # an intersection below the sampled peak is geometrically inconsistent
        bad = ~valid | ~np.isfinite(err_x) | ~np.isfinite(err_y) | (err_y < 0)

    xi[bad] = np.nan
    yi[bad] = np.nan
    err_x[bad] = np.nan
    err_y[bad] = np.nan

    return xi, pt_mod * yi, err_x, pt_mod * err_y


# Return type of peak_uncertainty, defined once at module level so every call
# returns instances of the same (picklable) class.
UncertaintyResult = namedtuple('UncertaintyResult', ['xerr_crest', 'zerr_crest', 'xerr_trough', 'zerr_trough'])


def peak_uncertainty(X, Y, Z, pm, tm):
    """Estimate crest and trough position/value uncertainties on a set of transects.

    Each column ``i`` of the 2-D inputs is treated as one transect and passed
    to ``transect_peak_uncertainty`` twice.  The first call handles the crests
    flagged in ``pm[:, i]``; the second handles the troughs flagged in
    ``tm[:, i]``.

    Parameters
    ----------
    X, Y : ndarray, 2-D
        Horizontal sample coordinates, one transect per column.  Samples whose
        distance from the origin is non-finite are skipped; crest/trough flags
        on skipped samples are ignored.
    Z : ndarray, 2-D
        Sample values, same layout as ``X``.
    pm, tm : ndarray of bool, 2-D
        Crest and trough flags, same layout as ``X``.

    Returns
    -------
    UncertaintyResult
        ``(xerr_crest, zerr_crest, xerr_trough, zerr_trough)``, each a masked
        array shaped like ``X``.  Only flagged, non-skipped samples can carry a
        value.  Every other entry is masked, as is every flagged sample for
        which no estimate was possible.
    """
    # Along-transect coordinate: distance from the origin, shifted so each
    # transect (column) starts at 0.
    # NOTE(review): this is only a monotonic along-transect coordinate if each
    # transect runs away from (0, 0) — confirm for the intended grids.
    pos = np.ma.masked_invalid((X**2 + Y**2)**.5)
    pos -= pos.min(axis=0)
    # getmaskarray always gives a full boolean array.  ``pos.mask`` can be the
    # scalar ``nomask``, and indexing with a scalar bool adds an axis instead
    # of selecting rows.
    invalid = np.ma.getmaskarray(pos)
    # plain data; the masked samples are dropped explicitly per transect
    pos = np.ma.getdata(pos)

    # empty results matrices
    xp_err = np.full(X.shape, np.nan)
    zp_err = np.full(X.shape, np.nan)

    xt_err = np.full(X.shape, np.nan)
    zt_err = np.full(X.shape, np.nan)

    # loop transects
    for i in range(X.shape[1]):
        keep = ~invalid[:, i]
        # Grid row of every kept sample.  The per-transect results only cover
        # kept samples, so they must be scattered back through these rows.
        rows = np.flatnonzero(keep)
        tr_pos = pos[keep, i]
        tr_z = Z[:, i][keep]
        tr_pm = np.asarray(pm[:, i][keep], dtype=bool)
        tr_tm = np.asarray(tm[:, i][keep], dtype=bool)

        # crests
        _, _, xe, ye = transect_peak_uncertainty(tr_pos, tr_z, tr_pm, peaks=True)
        xp_err[rows[tr_pm], i] = xe
        zp_err[rows[tr_pm], i] = ye

        # troughs
        _, _, xe, ye = transect_peak_uncertainty(tr_pos, tr_z, tr_tm, peaks=False)
        xt_err[rows[tr_tm], i] = xe
        zt_err[rows[tr_tm], i] = ye

    return UncertaintyResult(np.ma.masked_invalid(xp_err),
                             np.ma.masked_invalid(zp_err),
                             np.ma.masked_invalid(xt_err),
                             np.ma.masked_invalid(zt_err))

