from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import ConnectionPatch, RegularPolygon
from matplotlib.path import Path as mPath
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from quantify_core.data.handling import (
    load_dataset,
    set_datadir,
)

# Register "../data" as the quantify data directory. The path is relative, so
# it depends on the current working directory the script is started from.
# NOTE(review): the load_dataset() calls further down presumably resolve their
# identifiers against this directory -- confirm against quantify_core docs.
set_datadir(Path("../data"))


def generate_vertices(dataset, vertices_list):
    """Translate polygon corners from axis coordinates into grid indices.

    For every ``(x, y)`` corner the result holds the index of the first entry
    of ``dataset.x0`` that is ``>= x`` and the index of the first entry of
    ``dataset.x1`` that is ``>= y``.

    Args:
        dataset: Object exposing the two axis value arrays as ``x0`` and ``x1``.
        vertices_list: Sequence of ``(x, y)`` corners in axis coordinates.

    Returns:
        A list of ``(x_index, y_index)`` tuples, one per corner, in input order.

    Raises:
        IndexError: If a corner lies beyond every value of its axis, i.e. no
            axis entry is ``>=`` the requested coordinate.
    """
    return [
        (
            # [0] selects the index array of the single dimension, the second
            # [0] the first matching position along it.
            np.where(dataset.x0 >= corner[0])[0][0],
            np.where(dataset.x1 >= corner[1])[0][0],
        )
        for corner in vertices_list
    ]


def generate_masks(dataset, vertices_list):
    """Build boolean grid masks for the inside and the outside of a polygon.

    The polygon corners are first converted to grid indices with
    ``generate_vertices``; every grid cell ``(x_index, y_index)`` is then
    tested against that polygon.

    Args:
        dataset: Object exposing the axis value arrays as ``x0`` and ``x1``.
        vertices_list: Sequence of ``(x, y)`` polygon corners in axis
            coordinates.

    Returns:
        A tuple ``(mask_inner, mask_outer)`` of boolean arrays with shape
        ``(len(dataset.x1), len(dataset.x0))``. ``mask_inner`` is ``True``
        where the polygon contains the cell; ``mask_outer`` is its negation.
    """
    n_cols, n_rows = len(dataset.x0), len(dataset.x1)
    polygon = mPath(generate_vertices(dataset, vertices_list))

    # One (x_index, y_index) pair per grid cell, enumerated row by row so the
    # flat result can be folded straight back into (rows, cols).
    row_idx, col_idx = np.indices((n_rows, n_cols))
    grid_points = np.column_stack((col_idx.ravel(), row_idx.ravel()))

    mask_inner = polygon.contains_points(grid_points).reshape((n_rows, n_cols))
    return mask_inner, ~mask_inner


def plot_vertices(ax, vertices_list, color):
    """Outline a polygon on ``ax``: a marker on every corner plus its edges.

    Args:
        ax: Matplotlib axes to draw on. Corners are interpreted in its data
            coordinates.
        vertices_list: Sequence of ``(x, y)`` corners in drawing order. The
            outline is closed automatically (last corner back to the first).
            An empty sequence draws nothing.
        color: Colour of the connecting edges only. The corner markers keep
            the default face colour and a white edge.
    """
    # Corner markers are small pentagons (numVertices=5).
    for corner in vertices_list:
        marker = RegularPolygon(
            (corner[0], corner[1]),
            numVertices=5,
            radius=0.4,
            alpha=1,
            edgecolor="white",
        )
        ax.add_patch(marker)

    # Edges: corner k -> corner k+1, wrapping around so the final edge closes
    # the outline. The modulo replaces a separate copy of the drawing code and
    # makes the empty case a no-op instead of an IndexError.
    n_corners = len(vertices_list)
    for idx, start in enumerate(vertices_list):
        end = vertices_list[(idx + 1) % n_corners]
        edge = ConnectionPatch(
            xyA=(start[0], start[1]),
            coordsA="data",
            axesA=ax,
            xyB=(end[0], end[1]),
            coordsB="data",
            axesB=ax,
            color=color,
        )
        ax.add_artist(edge)


# Scan shown in panel (a) and used for the histogram in panel (b).
# NOTE(review): judging by the variable names, y1 is the finesse, y5 the
# polarization splitting and y9 the R^2 of the splitting fit -- confirm
# against the measurement definition.
splitting_data_vincent_vega_high = load_dataset(
    "20241218-043859-573-85575c-cavity_lateral_finesse_scan"
)
# Every map is transposed and then reversed along its first axis, so all
# arrays derived from one scan share the same orientation and shape.
finesse_vincent_vega_high = splitting_data_vincent_vega_high.y1.values.transpose()[::-1]
splitting_vincent_vega_high = splitting_data_vincent_vega_high.y5.values.transpose()[
    ::-1
]
rsquared_vincent_vega_high = splitting_data_vincent_vega_high.y9.values.transpose()[
    ::-1
]
x_values_splitting_vega_high = splitting_data_vincent_vega_high.x0
y_values_splitting_vega_high = splitting_data_vincent_vega_high.x1

# Scan shown in panels (c) and (d).
# NOTE(review): y8 is presumably the R^2 of the finesse fit -- confirm.
splitting_data_vincent_vega_zoom = load_dataset(
    "20241217-002005-736-12e3cf-cavity_lateral_finesse_scan"
)
finesse_vincent_vega_zoom = splitting_data_vincent_vega_zoom.y1.values.transpose()[::-1]
splitting_vincent_vega_zoom = splitting_data_vincent_vega_zoom.y5.values.transpose()[
    ::-1
]
# `.values` keeps this a plain ndarray like every other map in this section.
rsquared_vincent_vega_zoom_finesse = (
    splitting_data_vincent_vega_zoom.y8.values.transpose()[::-1]
)
rsquared_vincent_vega_zoom = splitting_data_vincent_vega_zoom.y9.values.transpose()[
    ::-1
]
x_values_splitting_vega_zoom = splitting_data_vincent_vega_zoom.x0
y_values_splitting_vega_zoom = splitting_data_vincent_vega_zoom.x1

# Blank out every pixel whose R^2 is at or below 0.95 so it is rendered with
# the colormap's "bad" colour. Boolean-mask assignment covers the whole array
# whatever its shape; an explicit (i, j) loop over len(x) x len(y) is only
# correct when the scan happens to be square.
splitting_vincent_vega_zoom[rsquared_vincent_vega_zoom <= 0.95] = np.nan
finesse_vincent_vega_zoom[rsquared_vincent_vega_zoom_finesse <= 0.95] = np.nan

# Polygon used to select the data for the histogram (from Fig. 3).
INNER_VERTICES_LIST_VV_H = [  # (x,y) in um, starting from the top left, clockwise
    (12, 50),
    (14, 68),
    (56, 100),
    (93, 50),
    (42, 10),
]
mask_diamond_vincent_vega = generate_masks(
    splitting_data_vincent_vega_high, INNER_VERTICES_LIST_VV_H
)[0]

# Histogram in (b): keep only pixels that lie on the diamond (inside the
# polygon) and whose R^2 exceeds 0.95; everything else becomes NaN.
on_diamond_and_well_fitted = mask_diamond_vincent_vega & (
    rsquared_vincent_vega_high > 0.95
)
splitting_vincent_vega_high_diamond = np.where(
    on_diamond_and_well_fitted, splitting_vincent_vega_high, np.nan
)

is_valid = np.logical_not(np.isnan(splitting_vincent_vega_high_diamond))
splitting_wo_nans = splitting_vincent_vega_high_diamond[is_valid]

# 100 equal bins over 0-20 GHz, i.e. 0.2 GHz per bin; hist[0] is the first
# ("zero") bin, which is reported separately from the bar plot.
hist, bins = np.histogram(splitting_wo_nans, bins=100, range=(0.0, 20))

# 2x2 figure: (a) splitting map of the first scan, (b) splitting histogram,
# (c) splitting map and (d) finesse map of the second scan.
fig = plt.figure(figsize=(8, 6))
panel = gridspec.GridSpec(
    2, 2, width_ratios=[1, 1], wspace=0.4, hspace=0.25, figure=fig
)
ax_1 = fig.add_subplot(panel[0, 0])
ax_2 = fig.add_subplot(panel[0, 1])
ax_3 = fig.add_subplot(panel[1, 0])
ax_4 = fig.add_subplot(panel[1, 1])

# NaN pixels are drawn in black. with_extremes() returns a modified copy, so
# the shared "viridis" colormap object is not mutated for the whole session.
cmap = plt.cm.viridis.with_extremes(bad="black")


def _draw_map(ax, x_values, y_values, values, vmax, cbar_label, panel_label):
    """Draw one colour-mapped scan on ``ax`` with colourbar and axis labels.

    Returns the mesh and its colourbar so callers can tweak them further.
    """
    ax.text(-0.3, 1.05, panel_label, transform=ax.transAxes, fontsize=12, va="top")
    mesh = ax.pcolormesh(x_values, y_values, values, vmax=vmax, cmap=cmap)
    # Name the parent axes explicitly instead of relying on an implicit lookup.
    colorbar = ax.figure.colorbar(mesh, ax=ax)
    colorbar.set_label(cbar_label)
    ax.set_xlabel("X Position (µm)")
    ax.set_ylabel("Y Position (µm)")
    return mesh, colorbar


# (a) splitting map with the histogram polygon and the zoom-region marker.
img, cbar = _draw_map(
    ax_1,
    x_values_splitting_vega_high,
    y_values_splitting_vega_high,
    splitting_vincent_vega_high,
    vmax=15,
    cbar_label="Polarization Splitting (GHz)",
    panel_label="(a)",
)
ax_1.set_xlim(0, 110)
ax_1.set_ylim(0, 110)
plot_vertices(ax_1, INNER_VERTICES_LIST_VV_H, "white")
mark_inset(ax_1, ax_3, loc1=2, loc2=1, fc="none", ec="0.5", ls=":")

# (b) histogram without the first ("zero") bin, which is printed below.
# bins[1:-1] are the LEFT edges of bins 1..N-1, so the bars must be anchored
# at their edge; the default centre alignment would shift every bar half a
# bin to the left. The bar width is taken from the actual bin edges.
bin_width = bins[1] - bins[0]
ax_2.text(-0.24, 1.1, "(b)", transform=ax_2.transAxes, fontsize=12, va="top")
ax_2.bar(
    bins[1:-1],
    hist[1:],
    alpha=0.7,
    width=bin_width,
    align="edge",
    label="Diamond",
    color=plt.colormaps.get_cmap("viridis")(60),
)
ax_2.set_xlabel("Polarization Splitting (GHz)")
ax_2.set_ylabel("Occurrences")
ax_2.legend(loc="upper right")

print(f"Polarization spltting, occurrences in zero bin: {hist[0]}")
print(f"Polarization spltting, occurrences in all other bins: {np.sum(hist[1:])}")

# (c) and (d) share the same scan and the same x tick positions.
zoom_xticks = [30, 35, 40, 45, 50]

img, cbar = _draw_map(
    ax_3,
    x_values_splitting_vega_zoom,
    y_values_splitting_vega_zoom,
    splitting_vincent_vega_zoom,
    vmax=10,
    cbar_label="Polarization Splitting (GHz)",
    panel_label="(c)",
)
ax_3.set_xticks(zoom_xticks)

img, cbar = _draw_map(
    ax_4,
    x_values_splitting_vega_zoom,
    y_values_splitting_vega_zoom,
    finesse_vincent_vega_zoom,
    vmax=10000,
    cbar_label=r"Finesse $\mathcal{F}$",
    panel_label="(d)",
)
ax_4.set_xticks(zoom_xticks)

fig.savefig(Path("../Fig_5.png"), dpi=600, bbox_inches="tight")
