from PyQt4 import QtGui
import numpy as np
from matplotlib import pyplot as plt
from viridis import viridis2 as viridis
from analysis import utils, fourier
from .polygon_mask import FourierMaskCreator
from . import qtcanvas


def get_clim(datamean, datarange):
    """Return colour limits of width ``|datarange|`` centred on ``datamean``.

    The raw interval ``(datamean - |datarange|/2, datamean + |datarange|/2)``
    is handed to ``utils.get_limits`` and that helper's result is returned
    unchanged.
    """
    # A negative range is treated as its magnitude.
    span = np.abs(datarange) if datarange < 0 else datarange
    half = .5 * span
    lower, upper = datamean - half, datamean + half
    return utils.get_limits((lower, upper))


class PolygonFilterToolbar(QtGui.QWidget):
    """Toolbar widget placed above the filter canvas.

    This is currently a placeholder: ``build`` adds no controls yet.
    """

    def __init__(self, parent=None):
        super(PolygonFilterToolbar, self).__init__(parent)
        self.build()

    def build(self):
        """Populate the toolbar with its controls (none are added yet)."""
        return None


class FourierFilterWindow(QtGui.QMainWindow):
    """Main window for interactively Fourier-filtering gridded data.

    Three panels are shown side by side:

    * 'Before' -- the original data;
    * the spectrum of the prepared data, with an editable polygon mask;
    * 'After' -- the masked spectrum transformed back via ``reverse``.
      This panel is redrawn whenever the mask creator invokes
      ``_apply_mask``.

    Parameters
    ----------
    x, y : array_like
        Grid coordinates passed to ``pcolormesh`` and ``Fourier.transform``.
    data : array_like
        Values on the grid. ``_apply_mask`` reads ``data.mask``, so this is
        presumably a ``numpy.ma.MaskedArray`` -- TODO confirm with callers.
    parent : QtGui.QWidget, optional
        Qt parent widget.
    """

    def __init__(self, x, y, data, parent=None):
        super(FourierFilterWindow, self).__init__(parent)
        self.x, self.y, self.data = x, y, data

        # Spectrum the mask is actually applied to in _apply_mask.
        self.f = fourier.Fourier.transform(
            self.x, self.y, utils.prepare(self.data, unmask=True))

        # Spectrum of the tapered/averaged data. It is used for the displayed
        # middle panel and for the kx/ky axes given to the mask creator.
        self.f_prep = fourier.Fourier.transform(
            self.x, self.y,
            utils.prepare(self.data, unmask=True, taper=True, avg=True))

        # The colour span is fixed from the original data so that the
        # 'Before' and 'After' panels use the same width of colour scale.
        self.data_range = np.amax(self.data) - np.amin(self.data)
        self.vmin, self.vmax = get_clim(np.mean(self.data), self.data_range)

        self.build()

    def build(self):
        """Create the central widget: toolbar on top, figure canvas below."""
        self.mainwidget = QtGui.QWidget()
        self.setCentralWidget(self.mainwidget)

        # NOTE(review): this attribute shadows QWidget.layout(). It is kept
        # under the same name in case other code reads ``self.layout``.
        self.layout = QtGui.QVBoxLayout()
        self.mainwidget.setLayout(self.layout)

        self.toolbar = PolygonFilterToolbar()
        self.layout.addWidget(self.toolbar)

        self.figure = plt.figure(figsize=(15, 4))
        self.canvas = qtcanvas.CustomCanvas(self.figure)
        self.layout.addWidget(self.canvas)
        self.plot(self.figure)
        self.draw()

    def plot(self, fig):
        """Create the three panels on *fig* and attach the polygon mask editor.

        Populates ``self.axes`` in the order before, spectrum, after.
        Populates ``self.colorbars``, keyed by the before/after axes.
        """
        self.axes = []
        self.colorbars = {}

        # Original data (ax1)
        self._add_data_panel(fig, [.04, .1, .25, .8], 'Before')

        # Fourier spectrum with an editable mask polygon (ax2)
        ax = fig.add_axes([.38, .1, .25, .8])
        self.axes.append(ax)
        self.f_prep.plot(ax, scale=True)

        # NOTE(review): if FourierMaskCreator invokes the callback while it
        # is being constructed, self.axes[2] does not exist yet -- confirm.
        self.mask_creator = FourierMaskCreator(
            ax, self.data.shape, self.f_prep.kx, self.f_prep.ky,
            callback=self._apply_mask, poly_xy=self._default_polygon(ax))

        # Filtered data (ax3). It starts out showing the original data.
        self._add_data_panel(fig, [.7, .1, .25, .8], 'After')

    def _add_data_panel(self, fig, rect, title):
        """Add an axes at *rect* showing ``self.data`` with a colorbar."""
        ax = fig.add_axes(rect)
        self.axes.append(ax)
        mesh = ax.pcolormesh(self.x, self.y, self.data, cmap=viridis,
                             vmin=self.vmin, vmax=self.vmax)
        ax.set_title(title)
        # The mesh's norm already carries vmin/vmax, so the colorbar picks up
        # the right limits without a separate set_clim/draw_all.
        self.colorbars[ax] = fig.colorbar(mesh, ax=ax)
        return ax

    @staticmethod
    def _default_polygon(ax):
        """Return the corners of a rectangle inset by a quarter of *ax*'s span.

        The corners are ordered (x1, y1), (x1, y2), (x2, y2), (x2, y1).
        """
        (xlo, xhi), (ylo, yhi) = ax.get_xlim(), ax.get_ylim()
        dx = (xhi - xlo) / 4.0
        dy = (yhi - ylo) / 4.0
        x1, x2 = xlo + dx, xhi - dx
        y1, y2 = ylo + dy, yhi - dy
        return ((x1, y1), (x1, y2), (x2, y2), (x2, y1))

    def _apply_mask(self, mask):
        """Filter the data with *mask* and redraw the 'After' panel.

        This is the callback handed to the FourierMaskCreator. The steps are:

        1. Apply the mask to ``self.f``.
        2. Transform back to ``self.data.shape``, reusing the data's mask.
        3. Plot with a colour span equal to the original data range, centred
           on the new mean.
        """
        f_masked = self.f.apply_mask(mask, shift=False)
        newdata = f_masked.reverse(self.data.shape, nanmask=self.data.mask)
        # NOTE(review): if ``reverse`` fills masked cells with NaN (as the
        # ``nanmask`` name suggests), np.mean returns NaN here -- confirm.
        vmin, vmax = get_clim(np.mean(newdata), self.data_range)

        ax = self.axes[2]
        # Axes.hold() was deprecated in matplotlib 2.0 and removed in 3.0.
        # Clearing explicitly reproduces hold(False) on every version.
        ax.clear()
        mesh = ax.pcolormesh(self.x, self.y, newdata, cmap=viridis,
                             vmin=vmin, vmax=vmax)
        ax.set_title('After')
        # clear() removed the mesh the colorbar was bound to. Re-attach it so
        # that it reflects the new limits.
        self.colorbars[ax].update_normal(mesh)
        self.canvas.draw()

    def draw(self):
        """Refresh the canvas geometry hint and redraw the figure."""
        self.canvas.updateGeometry()
        self.canvas.draw()

    def resizeEvent(self, event):
        """Let Qt process the resize, then redraw the figure."""
        super(FourierFilterWindow, self).resizeEvent(event)
        self.draw()


def run(x, y, Z):
    """Show a FourierFilterWindow for the grid (x, y, Z) and run the Qt loop.

    A new QApplication is created from ``sys.argv``. When the event loop
    ends, the interpreter exits with the application's return code.
    """
    import sys

    application = QtGui.QApplication(sys.argv)

    window = FourierFilterWindow(x, y, Z)
    window.setWindowTitle('Fourier filter for sand wave data')
    window.show()

    exit_code = application.exec_()
    sys.exit(exit_code)