from __future__ import print_function, division
import numpy as np
import pandas as pd


# Public names exported by ``from <this module> import *``: the dataset class
# plus the module-level ``load``/``save`` convenience wrappers defined below.
__all__ = ['PointDataset', 'load', 'save']


class PointDataset(object):
    """
    Dataset of point data backed by a :class:`pandas.DataFrame`.

    The table is stored in ``self.df``. Any extra keyword arguments passed
    to the constructor are kept verbatim in the ``self.meta`` dict and are
    carried over by :meth:`subset` and :meth:`delete_by_mask`.

    """

    def __init__(self, data, columns=None, **meta):
        # Non-DataFrame input (ndarray, list of rows, dict of columns, ...)
        # is wrapped; ``columns`` is only used in that case and is ignored
        # when ``data`` is already a DataFrame.
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data, columns=columns)
        # NOTE: an incoming DataFrame is stored as-is (not copied), so
        # in-place edits through this object are visible to the caller.
        self.df = data
        self.meta = meta

    @property
    def data(self):
        """Return the table contents as a 2D NumPy array (rows x columns)."""
        # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``.values``
        # returns the same array and is available in every pandas version.
        return self.df.values

    @property
    def columns(self):
        """Column labels of the underlying DataFrame."""
        return self.df.columns

    def __getitem__(self, item):
        """Return ``self.df[item]`` (a column, or a sub-frame for a list)."""
        return self.df[item]

    def __setitem__(self, key, value):
        """Add or overwrite column ``key`` in place."""
        self.df[key] = value

    def __delitem__(self, key):
        """Drop column(s) ``key`` in place."""
        self.df.drop(key, axis=1, inplace=True)

    def subset(self, keys):
        """Return a new dataset with only the columns ``keys``; meta is kept."""
        return PointDataset(self.df[keys], **self.meta)

    @property
    def normalized(self):
        """
        Return the data min-max scaled column-wise into ``[0, 1]``.

        Columns whose values are all equal (zero range) map to 0 rather
        than NaN. Raises ValueError for an empty (zero-row) dataset.
        """
        data = np.asarray(self.data, dtype=float)
        lo = data.min(axis=0)
        span = data.max(axis=0) - lo
        # A constant column would otherwise give 0/0 -> NaN plus a warning.
        span[span == 0] = 1.0
        return (data - lo) / span

    @property
    def shape(self):
        """``(n_rows, n_columns)`` of the underlying DataFrame."""
        return self.df.shape

    def as_dict(self, orient='dict'):
        """
        Return the table as a dict via :meth:`pandas.DataFrame.to_dict`.

        With the default ``orient='dict'`` the result is
        ``{column: {index: value}}``; other ``orient`` values are passed
        straight through to pandas.
        """
        # DataFrame has no ``as_dict`` method; ``to_dict`` is the real API.
        return self.df.to_dict(orient=orient)

    def keys(self):
        """Yield the column labels, dict-style."""
        for col in self.columns:
            yield col

    def values(self):
        """Yield each column as a Series, dict-style."""
        for col in self.columns:
            yield self.df[col]

    def items(self):
        """Yield ``(label, column Series)`` pairs, dict-style."""
        for col in self.columns:
            yield col, self.df[col]

    @classmethod
    def load(cls, f, columns=None, index_col=False, **kwargs):
        """
        Read a CSV file (path or buffer) into a new dataset.

        ``columns`` is forwarded as ``names`` to :func:`pandas.read_csv`;
        ``index_col=False`` keeps pandas from turning the first column into
        the index. Other keyword arguments also go to ``read_csv``.
        """
        df = pd.read_csv(f, names=columns, index_col=index_col, **kwargs)
        return cls(df)

    def delete_by_mask(self, rowmask):
        """
        Return a new dataset without the rows where ``rowmask`` is True.

        ``rowmask`` may be any 1D boolean array-like (list, ndarray, Series)
        with one entry per row. Column dtypes and ``meta`` are preserved and
        the result is re-indexed 0..n-1.

        Raises ValueError if ``rowmask`` is not 1D and IndexError if its
        length differs from the number of rows.
        """
        rowmask = np.asarray(rowmask, dtype=bool)
        if rowmask.ndim != 1:
            raise ValueError('rowmask must be a 1D array')
        nrows = self.df.shape[0]
        if rowmask.shape[0] != nrows:
            # Same exception type NumPy raised for the previous ndarray
            # boolean indexing, so existing ``except IndexError`` still works.
            raise IndexError('rowmask has {} entries but dataset has {} rows'
                             .format(rowmask.shape[0], nrows))
        # Select on the DataFrame rather than ``self.data`` so per-column
        # dtypes survive instead of collapsing to one (often object) dtype.
        kept = self.df.iloc[~rowmask].reset_index(drop=True)
        return PointDataset(kept, **self.meta)

    def save(self, filename, index=False, sep=',', **kwargs):
        """Write the table to ``filename`` via :meth:`DataFrame.to_csv`."""
        self.df.to_csv(filename, sep=sep, index=index, **kwargs)

    def __str__(self):
        # Module-qualified class name, followed by the DataFrame's own text
        # rendering indented by four spaces.
        return '{}.{}\n    {}'.format(
            __name__,
            self.__class__.__name__,
            str(self.df).replace('\n', '\n    '))


def load(filename, **kwargs):
    """
    Read ``filename`` into a :class:`PointDataset`.

    Convenience alias for ``PointDataset.load``: the file name and every
    keyword argument are forwarded to it unchanged.
    """
    loader = PointDataset.load
    return loader(filename, **kwargs)


def save(data, columns, filename, **kwargs):
    """
    Write ``data`` to ``filename`` in one call.

    ``data`` and ``columns`` are wrapped in a :class:`PointDataset` and
    written with ``PointDataset.save``; keyword arguments are forwarded to
    it, and its return value is passed back to the caller.
    """
    dataset = PointDataset(data, columns)
    outcome = dataset.save(filename, **kwargs)
    return outcome
