from __future__ import division
import numpy as np
from matplotlib.colors import ListedColormap, Colormap
from matplotlib import cm
import colorsys


def axpositions(figsize, shape, margin=.05, padding=.1, box=(0, 0, 1, 1)):
    """
    Lay out a rows x cols grid of equally sized axes rectangles.

    Margins and paddings are given in the vertical direction; their horizontal
    counterparts are divided by the figure aspect ratio (width / height).

    :param figsize: (width, height) of the figure
    :param shape: (rows, cols) of the grid
    :param margin: outer margin, as a fraction of the box height
    :param padding: gap between neighbouring axes, as a fraction of the box height
    :param box: (xll, yll, xur, yur) figure-fraction region the grid is mapped into
    :return: array with one [left, bottom, width, height] row per axes, ordered
             row-major: grid rows bottom-up, columns left-to-right within a row
    """
    fig_w, fig_h = figsize
    rows, cols = shape
    aspect = fig_w/fig_h

    # horizontal spacings are rescaled by the figure aspect ratio
    x_margin, y_margin = margin/aspect, margin
    x_pad, y_pad = padding/aspect, padding

    width = (1-(2*x_margin+(cols-1)*x_pad))/cols
    height = (1-(2*y_margin+(rows-1)*y_pad))/rows

    def to_box(col, row):
        # position inside the unit square, then mapped linearly into `box`
        box_w = box[2] - box[0]
        box_h = box[3] - box[1]
        left = x_margin + col*(width + x_pad)
        bottom = y_margin + row*(height + y_pad)
        return [left*box_w + box[0],
                bottom*box_h + box[1],
                width*box_w,
                height*box_h]

    return np.array([to_box(col, row)
                     for row in range(rows)
                     for col in range(cols)])


def scaled_subplots(fig, shape, **kwargs):
    """
    Add a grid of axes to an existing figure, laid out by ``axpositions``.

    :param fig: matplotlib figure; its current size in inches drives the layout
    :param shape: (rows, cols) of the grid
    :param kwargs: forwarded to ``axpositions`` (margin, padding, box)
    :return: list of the created axes, in the order ``axpositions`` returns them
    """
    size_inches = tuple(fig.get_size_inches())
    return [fig.add_axes(rect)
            for rect in axpositions(size_inches, shape, **kwargs)]


def hsubplots(figwidth, shape, hpad=0, vpad=0, box=(0, 0, 1, 1), ax_aspect=1):
    """
    Calculate subplot axes positions for a figure of fixed width.

    The figure height is derived from the grid, the paddings and the requested
    axes aspect, so that every axes gets the physical height/width ratio
    ``ax_aspect``.

    :param figwidth: width in inches of the figure to build
    :param shape: (rows, cols) of the subplot grid; both must be positive
    :param hpad: horizontal padding between columns, as a fraction of figure width
    :param vpad: vertical padding between rows, as a fraction of figure height
    :param box: (xll, yll, xur, yur) figure-fraction region filled by the grid
    :param ax_aspect: physical height/width ratio of each individual axes
    :return: figsize (w, h), axpositions
             axpositions is an array of shape (rows, cols, 4) holding
             [left, bottom, width, height] per subplot; row 0 is the top row.
    :raises ValueError: if ``shape`` contains a non-positive count

    Example (requires matplotlib)::

        from matplotlib import pyplot as plt
        figsize, positions = hsubplots(10, (2, 3), box=(.05, .05, .95, .95))
        fig = plt.figure(figsize=figsize)
        for rowpos in positions:
            for pos in rowpos:
                fig.add_axes(pos)
        plt.show()

    >>> figsize, positions = hsubplots(10, (2, 3))
    >>> figsize[0]
    10
    >>> positions.shape
    (2, 3, 4)
    """
    m, n = shape
    if m < 1 or n < 1:
        raise ValueError("shape must contain positive row and column counts")

    box = np.array(box)

    axwidth = (box[2] - box[0] - hpad * (n - 1)) / n
    axheight = (box[3] - box[1] - vpad * (m - 1)) / m
    # Physical axes height is axheight * figheight and width is
    # axwidth * figwidth; requiring height == ax_aspect * width gives:
    figheight = figwidth * (axwidth / axheight) * ax_aspect

    # Rows are generated bottom-up (increasing y) ...
    axpos = [[[box[0] + j * axwidth + j * hpad,
               box[1] + i * axheight + i * vpad,
               axwidth, axheight]
              for j in range(n)]   # cols
             for i in range(m)]    # rows

    # ... and flipped so positions[0] is the top row, matching reading order.
    return (figwidth, figheight), np.array(axpos[::-1])
axpos_fixed_width = hsubplots  # backwards compatibility


def annotate(ax, label, offset=.05, dx=0, dy=0, radius=.05, **kwargs):
    """
    Write a label near the upper-left corner of an axes, on a circular marker.

    :param ax: matplotlib axes to annotate
    :param label: label to write (converted with ``str``)
    :param offset: distance of the marker centre from the left and top axes
                   edges, in axes fractions
    :param dx: horizontal text offset from the marker centre, in axes fractions
    :param dy: vertical text offset from the marker centre, in axes fractions
    :param radius: marker radius, as a fraction of the smaller axes side
                   (measured from the axes position at call time)
    :param kwargs: passed on to ``ax.text``
    """
    x, y = 0 + offset, 1 - offset

    # Marker size is a diameter in points (1/72 inch): convert the radius from
    # a fraction of the smaller axes side into points.
    figw, figh = ax.figure.get_size_inches()
    axbox = ax.get_position()
    side_inches = min(axbox.width * figw, axbox.height * figh)
    markersize = 2 * radius * side_inches * 72

    # An explicit marker is required: without one the point is invisible.
    ax.plot([x], [y], marker='o', linestyle='none', mfc='w', mec='k', alpha=.9,
            ms=markersize, zorder=2.5, transform=ax.transAxes)
    ax.text(x+dx, y+dy, str(label), ha='center', va='center', zorder=3,
            transform=ax.transAxes, **kwargs)


def alpha_colormap(cmap, alpha):
    """
    Return a copy of a colormap with its alpha channel replaced.

    Works for any Colormap subclass (ListedColormap, LinearSegmentedColormap, ...).

    :param cmap: a matplotlib Colormap instance or the name of a registered one
    :param alpha: opacity in [0, 1]; either a scalar or an array with one value
                  per colormap entry (length ``cmap.N``)
    :return: ListedColormap named ``<cmap.name>_alpha`` with ``cmap.N`` entries
    """
    if not isinstance(cmap, Colormap):
        try:
            # colormap registry (matplotlib >= 3.5); cm.get_cmap was removed in 3.9
            from matplotlib import colormaps
        except ImportError:
            cmap = cm.get_cmap(cmap)
        else:
            cmap = colormaps[cmap]
    # Sampling at the integer LUT indices yields an (N, 4) RGBA array for every
    # Colormap subclass; `.colors` only exists on ListedColormap and may be a
    # plain list (no `.shape`) or already contain an alpha column.
    colors = cmap(np.arange(cmap.N))
    colors[:, 3] = alpha
    return ListedColormap(colors, name=cmap.name+'_alpha')


def discrete_colors(N, s=.7, v=.8):
    """
    Lazily generate N RGB colors with evenly spaced hues.

    Hues are N equally spaced points of [0, 1) starting at 0 (the endpoint 1,
    the same hue as 0, is dropped); each is converted from HSV with the given
    saturation and value.

    :param N: number of colors
    :param s: HSV saturation
    :param v: HSV value (brightness)
    :return: generator of (r, g, b) tuples
    """
    hues = np.linspace(0, 1, N + 1)
    for hue in hues[:N]:
        rgb = colorsys.hsv_to_rgb(hue, s, v)
        yield rgb

