from __future__ import print_function, division
import numpy as np
from functools import partial


def percentile_hist(ax, x, y, color1='k', color2='r', levels=(10, 25, 50, 75, 90), lw=1, **kwargs):
    """Plot per-bin percentiles of ``y`` as a function of ``x``.

    One line is drawn for each percentile in ``levels`` (values as accepted by
    ``np.percentile``'s ``q`` argument, i.e. 0-100). The median (50) line uses
    ``color2``; all other lines use ``color1``. Each line is labelled with its
    percentile value followed by a percent sign, as a mathtext string;
    fractional percentiles (e.g. 2.5) keep their decimals in the label.

    Extra keyword arguments (e.g. ``N``) are forwarded to ``_line_hist``.

    Returns the tuple of line handles produced by ``_line_hist``.
    """
    fns = []
    props = []
    for level in levels:
        # partial binds q now, avoiding the late-binding-closure pitfall.
        fns.append(partial(np.percentile, q=level))
        color = color2 if level == 50 else color1
        # Raw string: '\%' is an invalid escape in a normal string literal.
        # '{:g}' keeps fractional levels instead of truncating them like int().
        props.append(dict(lw=lw, color=color, label=r'${:g}\%$'.format(level)))

    return _line_hist(ax, x, y, fns, props=props, **kwargs)


def std_hist(ax, x, y, color1='k', color2='r', levels=(-2, -1, 0, 1, 2), lw=1, **kwargs):
    """Plot per-bin ``mean(y) + level * std(y)`` as a function of ``x``.

    One line is drawn for each multiplier in ``levels``, using ``np.mean`` and
    ``np.std`` (numpy's default ``ddof=0``). The ``level == 0`` (mean) line
    uses ``color2``; all other lines use ``color1``. Labels are mathtext
    strings built by ``_std_label`` (mu, mu+sigma, mu-2sigma, ...). Any real
    ``level`` produces a correct label.

    Extra keyword arguments (e.g. ``N``) are forwarded to ``_line_hist``.

    Returns the tuple of line handles produced by ``_line_hist``.
    """
    fns = []
    props = []
    for level in levels:
        # Bind level as a default argument so each lambda keeps its own value.
        fns.append(lambda values, level=level: np.mean(values) + level * np.std(values))
        color = color2 if level == 0 else color1
        props.append(dict(lw=lw, color=color, label=_std_label(level)))

    return _line_hist(ax, x, y, fns, props=props, **kwargs)


def _std_label(level):
    """Return the mathtext legend label for ``mean + level * std``."""
    if level == 0:
        return r'$\mu$'
    sign = '+' if level > 0 else '-'
    magnitude = abs(level)
    # A unit multiplier is conventionally omitted: mu+sigma rather than mu+1sigma.
    coeff = '' if magnitude == 1 else '{:g}'.format(magnitude)
    return r'$\mu{}{}\sigma$'.format(sign, coeff)


def line_hist(ax, x, y, mode='std', **kwargs):
    """Dispatch to ``std_hist`` or ``percentile_hist`` according to ``mode``.

    Parameters
    ----------
    mode : {'std', 'percentile'}
        Which family of per-bin statistics to plot.
    **kwargs
        Forwarded unchanged to the selected function.

    Returns the tuple of line handles produced by the selected function.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'std'`` nor ``'percentile'``.
    """
    if mode == 'std':
        return std_hist(ax, x, y, **kwargs)
    if mode == 'percentile':
        return percentile_hist(ax, x, y, **kwargs)
    # Fail loudly instead of silently returning None on a typo'd mode.
    raise ValueError("mode must be 'std' or 'percentile', got {!r}".format(mode))


def _line_hist(ax, x, y, fns, N=20, props=(), min_count=3):
    """Plot one line per statistic in ``fns``, evaluated on ``y`` in bins of ``x``.

    ``[x.min(), x.max()]`` is split into ``N + 1`` equal-width bins (``N + 2``
    edges). For every function in ``fns`` and every bin, the function is
    applied to the ``y`` values whose ``x`` falls in that bin. Bins holding
    fewer than ``min_count`` samples are treated as missing (NaN) and left out
    of the plotted line.

    Parameters
    ----------
    ax : Axes-like
        Object whose ``plot(xs, ys, **kw)`` returns a one-element sequence of
        line handles (as matplotlib's ``Axes.plot`` does for a single line).
    x, y : array_like
        Samples of equal length; ``x`` selects the bin, ``y`` is reduced.
    fns : sequence of callables
        Each maps a 1-D array of ``y`` values to a scalar.
    N : int, optional
        Number of interior bin edges; yields ``N + 1`` bins.
    props : sequence of dict, optional
        ``props[i]`` is passed as keyword arguments to the plot call for
        ``fns[i]``; missing entries fall back to no extra arguments.
    min_count : int, optional
        Minimum number of samples a bin needs to contribute a point (default 3).

    Returns
    -------
    tuple
        One line handle per function in ``fns``, in order.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    thresholds = np.linspace(x.min(), x.max(), N + 2)
    # Digitizing against the interior edges only maps every sample -- including
    # x.max() -- to a bin index in 0..N.
    inds = np.digitize(x, thresholds[1:-1])

    # Bin centres, plus each bin's y-values selected once and reused by every fn.
    xpos = 0.5 * (thresholds[:-1] + thresholds[1:])
    binned = [y[inds == i] for i in range(len(xpos))]

    lhs = []
    for i, fn in enumerate(fns):
        line = np.array([fn(v) if v.size >= min_count else np.nan for v in binned])
        mask = ~np.isnan(line)
        try:
            plot_kw = props[i]
        except IndexError:
            plot_kw = {}

        lh, = ax.plot(xpos[mask], line[mask], **plot_kw)
        lhs.append(lh)

    return tuple(lhs)
