import torch
import os
import json
from torch.utils.data import Dataset
from torch.utils.data import random_split
import cv2 as cv
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
import pandas as pd


class TuSimpleDataset1(Dataset):
    """TuSimple lane dataset yielding an image and a binary lane-mask target.

    Every file directly inside ``root`` whose name ends with
    ``withclasses.json`` is read as JSON-lines; each non-blank line must
    provide ``lanes``, ``h_samples`` and ``raw_file`` (image path relative to
    ``root``).

    Args:
        root: Directory holding the annotation files and the images.
        transform: Optional joint transform, called as
            ``transform(image, target)`` and returning the transformed pair.
        input_size: Stored as ``self.input_size``; not used by this class.
    """

    def __init__(self, root, transform=None, input_size=(448, 448)):
        self.root = root
        self.transform = transform
        self.input_size = input_size
        # Sorted so that sample indices do not depend on os.listdir ordering.
        self.json_files = sorted(
            name for name in os.listdir(root) if name.endswith('withclasses.json')
        )
        self.lanes = []
        self.h_samples = []
        self.raw_file = []

        for json_name in self.json_files:
            with open(os.path.join(root, json_name), 'r', encoding='utf-8') as file:
                for line in file:
                    if not line.strip():
                        # Tolerate blank / trailing lines; json.loads('') would raise.
                        continue
                    annotation = json.loads(line)
                    self.lanes.append(annotation['lanes'])
                    self.h_samples.append(annotation['h_samples'])
                    self.raw_file.append(annotation['raw_file'])

        self.lane_images = [os.path.join(root, path) for path in self.raw_file]

    def create_lane_image(self, item, shape=None):
        """Rasterise the lanes of sample ``item`` into a 2-D mask.

        Consecutive points of each lane are joined with 2-pixel-thick line
        segments; a segment is skipped when either endpoint's x-coordinate
        is ``-2``.

        Args:
            item: Sample index.
            shape: ``(rows, cols)`` of the mask. Defaults to
                ``(self.w, self.h)``, which ``__getitem__`` sets from the
                loaded image's shape.

        Returns:
            float64 array of ``shape`` with 1.0 on lane pixels, 0.0 elsewhere.
        """
        if shape is None:
            shape = (self.w, self.h)
        mask = np.zeros(shape)
        ys = self.h_samples[item]
        for lane in self.lanes[item]:
            # Walk consecutive (y, x) pairs of this lane.
            for y0, y1, x0, x1 in zip(ys, ys[1:], lane, lane[1:]):
                if x0 != -2 and x1 != -2:
                    cv.line(mask, (int(x0), int(y0)), (int(x1), int(y1)), [255, 255, 255], 2)
        return mask / 255

    def __len__(self):
        return len(self.lanes)

    def __getitem__(self, item):
        """Return ``{'data': image, 'label': mask}`` for sample ``item``.

        Raises:
            FileNotFoundError: if the image cannot be read by OpenCV.
        """
        path = self.lane_images[item]
        raw = cv.imread(path)
        if raw is None:
            raise FileNotFoundError(f"could not read image: {path}")
        # Min-max stretch into the full uint8 range. Normalising to [0, 1]
        # without a float dtype keeps uint8 and rounds every pixel to 0 or 1;
        # PIL cannot hold a float 3-channel array, so scaling to [0, 1] is
        # left to the transform (e.g. ToTensor).
        lane_img_cv = cv.normalize(raw, None, alpha=0, beta=255, norm_type=cv.NORM_MINMAX)
        # NOTE(review): cv.imread returns BGR while PIL interprets the array as
        # RGB, so channels are swapped; left as-is to match existing data — confirm.
        self.lane_img = Image.fromarray(lane_img_cv)
        # ndarray shape is (rows, cols, channels); the legacy names w/h are kept
        # for compatibility even though self.w holds the row count.
        self.w, self.h, self.c = lane_img_cv.shape
        self.target_img = Image.fromarray(self.create_lane_image(item).astype(np.uint8))
        if self.transform:
            self.lane_img, self.target_img = self.transform(self.lane_img, self.target_img)
        target = self.target_img
        if not torch.is_tensor(target):
            # Without a transform the target is still a PIL image, which
            # torch.squeeze cannot accept.
            target = torch.from_numpy(np.array(target))
        return {'data': self.lane_img, 'label': torch.squeeze(target)}




class TuSimpleDataset2(Dataset):
    """TuSimple lane dataset yielding an image and a flat regression target.

    Every file directly inside ``root`` whose name ends with
    ``withclasses.json`` is read as JSON-lines; each non-blank line must
    provide ``lanes``, ``h_samples``, ``classes`` (a space-separated string
    of integers) and ``raw_file`` (image path relative to ``root``).

    Args:
        root: Directory holding the annotation files and the images.
        transform: Optional joint transform, called as
            ``transform(image, target_list)`` and returning the pair.
    """

    # Layout of the flat target produced by lane_pts().
    MAX_LANES = 5
    MIN_LANES = 2
    POINTS_PER_LANE = 56
    # 0.00078125 == 1/1280 (x scale); 0.001388889 ~= 1/720 (y scale).
    # Presumably the TuSimple frame width/height — confirm.
    X_SCALE = 0.00078125
    Y_SCALE = 0.001388889

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Sorted so that sample indices do not depend on os.listdir ordering.
        self.json_files = sorted(
            name for name in os.listdir(root) if name.endswith('withclasses.json')
        )
        self.lanes = []
        self.h_samples = []
        self.classes = []
        self.raw_file = []

        for json_name in self.json_files:
            with open(os.path.join(root, json_name), 'r', encoding='utf-8') as file:
                for line in file:
                    if not line.strip():
                        # Tolerate blank / trailing lines; json.loads('') would raise.
                        continue
                    annotation = json.loads(line)
                    self.lanes.append(annotation['lanes'])
                    self.h_samples.append(annotation['h_samples'])
                    self.classes.append(annotation['classes'])
                    self.raw_file.append(annotation['raw_file'])

        self.lane_images = [os.path.join(root, path) for path in self.raw_file]

    def lane_pts(self, item):
        """Build the flat target list for sample ``item``.

        Layout (341 entries for the supported case):
          * x-coordinates of every lane, flattened and scaled by ``X_SCALE``,
            zero-padded up to ``MAX_LANES * POINTS_PER_LANE`` values;
          * ``h_samples`` scaled by ``Y_SCALE``;
          * integer class ids parsed from ``classes``, zero-padded with one
            zero per missing lane.
        Every negative entry (e.g. scaled ``-2`` "no point" markers) is
        clamped to 0.

        Returns:
            The list above, or ``[]`` when the sample has no lanes, the first
            lane does not have ``POINTS_PER_LANE`` points, or the lane count
            is outside ``[MIN_LANES, MAX_LANES]``.
        """
        lanes = self.lanes[item]
        if not lanes or len(lanes[0]) != self.POINTS_PER_LANE:
            return []
        n_lanes = len(lanes)
        if not self.MIN_LANES <= n_lanes <= self.MAX_LANES:
            return []
        n_missing = self.MAX_LANES - n_lanes

        w_points = [float(x) * self.X_SCALE for lane in lanes for x in lane]
        w_points.extend([0] * (self.POINTS_PER_LANE * n_missing))
        h_points = [float(y) * self.Y_SCALE for y in self.h_samples[item]]
        class_int = [int(c) for c in self.classes[item].strip().split(' ')]
        class_int.extend([0] * n_missing)

        return [0 if value < 0 else value for value in w_points + h_points + class_int]

    def __len__(self):
        return len(self.lanes)

    def __getitem__(self, item):
        """Return ``[image, target_tensor]`` for sample ``item``.

        Raises:
            FileNotFoundError: if the image cannot be read by OpenCV.
        """
        path = self.lane_images[item]
        lane_img_cv = cv.imread(path)
        if lane_img_cv is None:
            raise FileNotFoundError(f"could not read image: {path}")
        # NOTE(review): cv.imread returns BGR while PIL interprets the array as
        # RGB, so channels are swapped; left as-is to match existing data — confirm.
        self.lane_img = Image.fromarray(lane_img_cv)
        self.target_lanes = [float(value) for value in self.lane_pts(item)]
        if self.transform:
            self.lane_img, self.target_lanes = self.transform(self.lane_img, self.target_lanes)
        return [self.lane_img, torch.tensor(self.target_lanes)]


# lanedataset = TuSimpleDataset2(root='/home/sandeep/PycharmProjects/LaneATT/datasets/tusimple/', transform=None)
# a = lanedataset[1][1]

class VOCDataset(torch.utils.data.Dataset):
    """YOLO-v1 style dataset: an image plus an ``S x S x (C + 5*B)`` target grid.

    ``csv_file`` is read with pandas; column 0 holds the image file name
    (joined with ``img_dir``) and column 1 the label file name (joined with
    ``label_dir``). Each non-blank label line is
    ``class x y width height``; coordinates are multiplied by ``S``, so they
    are assumed to be fractions of the image size — confirm against the data.

    Target layout per cell: ``[0:C]`` one-hot class, ``[C]`` objectness,
    ``[C+1:C+5]`` box ``(x_cell, y_cell, width_cell, height_cell)``; the
    remaining slots (further boxes) are left at 0. Only the first object
    falling into a cell is encoded.
    """

    def __init__(
            self, csv_file, img_dir, label_dir, S=7, B=2, C=20, transform=None,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        self.S = S
        self.B = B
        self.C = C

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(image, label_matrix)`` for row ``index`` of the CSV."""
        label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
        boxes = []
        with open(label_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # skip blank lines instead of failing to unpack
                values = []
                for token in fields:
                    number = float(token)
                    # Keep integral values as int; int(token) would raise on "1.0".
                    values.append(int(number) if number == int(number) else number)
                class_label, x, y, width, height = values
                boxes.append([class_label, x, y, width, height])

        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path)
        # reshape keeps an empty label file as a (0, 5) tensor rather than (0,).
        boxes = torch.tensor(boxes).reshape(-1, 5)

        if self.transform:
            image, boxes = self.transform(image, boxes)

        # Convert boxes to the per-cell target grid.
        label_matrix = torch.zeros((self.S, self.S, self.C + 5 * self.B))
        obj = self.C  # objectness slot of the first box; coordinates follow it
        for box in boxes:
            class_label, x, y, width, height = box.tolist()
            class_label = int(class_label)

            # (i, j) = (cell row, cell column). Clamp so a coordinate of exactly
            # 1.0 falls into the last cell instead of indexing past the grid.
            i = min(int(self.S * y), self.S - 1)
            j = min(int(self.S * x), self.S - 1)
            x_cell, y_cell = self.S * x - j, self.S * y - i

            # Width/height relative to one cell: (width * image_w) / (image_w / S)
            # simplifies to width * S (likewise for height).
            width_cell, height_cell = width * self.S, height * self.S

            # Only one object per cell is encoded: the first one wins.
            if label_matrix[i, j, obj] == 0:
                label_matrix[i, j, obj] = 1
                label_matrix[i, j, obj + 1:obj + 5] = torch.tensor(
                    [x_cell, y_cell, width_cell, height_cell]
                )
                # One-hot class encoding.
                label_matrix[i, j, class_label] = 1

        return image, label_matrix