import copy
import time
import warnings
from multiprocessing.pool import ThreadPool

import numpy as np

from surrogate.metrics import calc_metric
from surrogate.util.factory import dfs
from surrogate.util.misc import empty_dict_if_none


class Entry:
    """Record describing one registered model and how it was configured.

    Attributes are stored exactly as given, except ``params``, ``defaults``
    and ``kwargs``, which are each passed through ``empty_dict_if_none`` so a
    missing value does not stay ``None``.
    """

    def __init__(self,
                 name=None,
                 clazz=None,
                 params=None,
                 defaults=None,
                 kwargs=None,
                 obj=None) -> None:
        super().__init__()
        # identification of the model
        self.name, self.clazz = name, clazz
        # configuration dictionaries (normalised by empty_dict_if_none)
        self.params, self.defaults, self.kwargs = (
            empty_dict_if_none(cfg) for cfg in (params, defaults, kwargs)
        )
        # the (unfitted) model instance, if any
        self.obj = obj


def fit(model, X, y, kwargs, show_warning, raise_exception):
    """Fit a deep copy of ``model`` on ``(X, y)`` and return the copy.

    The passed ``model`` itself is never mutated, so the same template can be
    fitted on several partitions (possibly from several threads).

    Args:
        model: object exposing ``fit(X, y, **kwargs)``.
        X, y: training data handed to ``model.fit`` unchanged.
        kwargs: extra keyword arguments for ``model.fit``; ``None`` means none.
        show_warning: print a warning line when fitting fails.
        raise_exception: re-raise the fitting error instead of returning ``None``.

    Returns:
        The fitted copy, or ``None`` if fitting failed and ``raise_exception``
        is False.

    Raises:
        Whatever exception fitting raised, when ``raise_exception`` is True.
    """
    if kwargs is None:
        kwargs = {}

    try:
        model = copy.deepcopy(model)
        model.fit(X, y, **kwargs)
    except Exception as e:
        if show_warning:
            print("WARNING: %s has failed: %s" % (model, str(e)))
        if raise_exception:
            raise
        # callers treat None as "this model failed"
        return None

    return model


class Benchmark:
    """Fit a collection of models on train/test partitions and score them.

    Models are registered with ``add_models``, ``add_model`` or
    ``add_model_by_obj``. ``do`` fits every registered model on every
    partition, evaluates each name in ``self.metrics`` via ``calc_metric`` on
    the test indices, and appends one result dict per (partition, model) pair
    to ``self.data``.
    """

    # Metrics evaluated when the caller does not pass ``metrics``.
    DEFAULT_METRICS = ("mse", "r2", "mae", "corr-spear", "corr-kendall")

    def __init__(self,
                 models=None,
                 defaults=None,
                 kwargs=None,
                 params=None,
                 metrics=None,
                 show_warning=False,
                 raise_exception=False,
                 n_threads=None,
                 verbose=False) -> None:
        """
        Args:
            models: iterable of model classes, each expanded through
                ``add_models`` with the shared ``params``/``defaults``/``kwargs``.
                ``None`` registers nothing.
            defaults: constructor arguments applied to every model (they
                override clashing ``params``).
            kwargs: extra keyword arguments forwarded to each model's ``fit``.
            params: parameter space passed to ``add_models``; ``None`` makes each
                class supply its own via ``hyperparameters()``.
            metrics: metric names for ``calc_metric``; ``None`` uses
                ``DEFAULT_METRICS``.
            show_warning: print a line whenever fitting or evaluation fails.
            raise_exception: re-raise fitting/evaluation errors instead of
                dropping the failed entry.
            n_threads: ``None`` fits sequentially; otherwise the size of the
                thread pool used for fitting.
            verbose: print a line per successfully evaluated entry.
        """
        super().__init__()
        # Fresh list per instance: a list literal as default would be shared
        # (and mutable) across every Benchmark built without ``metrics``.
        self.metrics = list(self.DEFAULT_METRICS) if metrics is None else metrics
        self.data = []
        self.show_warning = show_warning
        self.raise_exception = raise_exception
        self.n_threads = n_threads
        self.verbose = verbose

        self.models = []
        for clazz in (models if models is not None else ()):
            self.add_models(clazz, params=params, defaults=defaults, kwargs=kwargs)

    def add_models(self, clazz, params=None, defaults=None, kwargs=None):
        """Register one ``clazz`` instance per item yielded by ``dfs(params)``.

        When ``params`` is None it is taken from ``clazz.hyperparameters()``.
        Each yielded item (presumably one concrete hyperparameter assignment)
        is passed to ``add_model`` as ``params``.
        """
        if params is None:
            params = clazz.hyperparameters()

        for entry in dfs(params):
            self.add_model(clazz, params=entry, defaults=defaults, kwargs=kwargs)

    def add_model(self, clazz, params=None, defaults=None, kwargs=None, name=None):
        """Instantiate ``clazz`` and register it as an ``Entry``.

        The constructor receives ``params`` merged with ``defaults``; on key
        clashes the value from ``defaults`` wins. ``name`` falls back to the
        class name.
        """
        params = empty_dict_if_none(params)
        defaults = empty_dict_if_none(defaults)

        # defaults are applied last, so they override clashing params
        ctor_args = dict(params)
        ctor_args.update(defaults)
        obj = clazz(**ctor_args)

        if name is None:
            name = clazz.__name__

        self.models.append(Entry(name=name, clazz=clazz, params=params, defaults=defaults,
                                 kwargs=kwargs, obj=obj))

    def add_model_by_obj(self, name, obj):
        """Register an already constructed model object under ``name``."""
        self.models.append(Entry(name=name, obj=obj))

    def do(self, X, y, X_test=None, y_test=None, partitions=None, label=None):
        """Fit and evaluate every registered model on every partition.

        Args:
            X, y: data indexed by the partition indices.
            X_test, y_test: if both are given, they are appended to ``X``/``y``
                and a single partition (original rows -> train, appended rows
                -> test) replaces ``partitions``.
            partitions: iterable of ``(train_idx, test_idx)`` pairs, required
                when no test set is given.
            label: tag stored in every result; defaults to ``time.time()``.

        Returns:
            The list of result dicts for entries that produced predictions
            (also appended to ``self.data``). A metric that could not be
            computed is left at ``inf``.

        Raises:
            Exception: if neither a test set nor ``partitions`` is given.
        """
        if X_test is not None and y_test is not None:
            n_train = len(X)
            partitions = [(range(n_train), range(n_train, n_train + len(X_test)))]
            # np.vstack replaces the deprecated np.row_stack alias
            X, y = np.vstack([X, X_test]), np.concatenate([y, y_test])
        elif partitions is None:
            raise Exception("Either provide a test set or partitions!")

        if label is None:
            label = time.time()

        # one job per (partition, model) pair
        ret = [
            dict(
                benchmark="benchmark@%s" % id(self),
                label=label,
                model=entry.name,
                obj=entry.obj,
                params=str(entry.params),
                kwargs=entry.kwargs,
                partition=i,
                train=train,
                test=test,
            )
            for i, (train, test) in enumerate(partitions)
            for entry in self.models
        ]

        # fit all the models; the returned list is aligned with ``ret``
        if self.n_threads is None:
            models = self.fit_serialized(X, y, ret)
        else:
            models = self.fit_parallelized(X, y, ret)

        for result, model in zip(ret, models):
            result["y_hat"] = None
            result["trn_mae"] = float("inf")
            for metric in self.metrics:
                result[metric] = float("inf")

            if model is None:
                # fitting failed (already reported by fit); dropped below
                continue

            try:
                tst = result["test"]
                y_hat = model.predict(X[tst])
                result["y_hat"] = y_hat

                # silence metric warnings for this block only instead of
                # permanently altering the process-wide warning filters
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    for metric in self.metrics:
                        result[metric] = calc_metric(metric, y[tst], y_hat)

                if self.verbose:
                    # fit() trains a deep copy, so read from the fitted model,
                    # not from the untouched template in result["obj"]
                    print(result["model"], result["params"], getattr(model, "time", None))
            except Exception as e:
                if self.show_warning:
                    print("WARNING: %s has failed: %s" % (model, str(e)))
                if self.raise_exception:
                    raise

        # filter out if the model has failed for whatever reason
        ret = [e for e in ret if e["y_hat"] is not None]

        self.data.extend(ret)

        return ret

    def fit_parallelized(self, X, y, results):
        """Fit every job on a thread pool of ``self.n_threads`` workers.

        Returns the fitted models (``None`` for failures) aligned with
        ``results``. Each job fits its own deep copy of the template model.
        """
        jobs = [(e["obj"], X[e["train"]], y[e["train"]], e["kwargs"],
                 self.show_warning, self.raise_exception) for e in results]
        with ThreadPool(self.n_threads) as pool:
            return pool.starmap(fit, jobs)

    def fit_serialized(self, X, y, results):
        """Fit every job sequentially; returns models aligned with ``results``."""
        return [fit(e["obj"], X[e["train"]], y[e["train"]], e["kwargs"],
                    self.show_warning, self.raise_exception) for e in results]

    def export(self, file):
        """Save ``self.data`` with ``np.save``.

        The list of dicts becomes an object array, which NumPy stores via
        pickle; reading it back needs ``np.load(file, allow_pickle=True)``.
        """
        np.save(file, self.data)

    def group_by_model(self):
        """Group the stored results by their ``obj`` value (model template)."""
        groups = {}
        for entry in self.data:
            groups.setdefault(entry["obj"], []).append(entry)
        return groups

    def results(self, metric="mae"):
        """Aggregate ``metric`` per (model, params) as median/min/max/mean/std."""
        import pandas as pd
        df = pd.DataFrame.from_dict(self.data)
        tbl = df.groupby(["model", "params"]).agg({metric: ['median', 'min', 'max', 'mean', 'std']})
        return tbl

    def print(self, metric="mae"):
        """Print the ``results`` table for ``metric``."""
        print(self.results(metric).to_string())
