import numpy as np

from surrogate.benchmark import Benchmark
from surrogate.partitioning.random import RandomPartitioning


class ModelSelection:
    """Select the best of several candidate models by benchmarking them.

    Each candidate is evaluated through ``Benchmark``. For every model, the
    worst (maximum) training error and the mean validation error across
    partitions are computed for ``metric``. Lower values are treated as
    better. The winner is then refit on the full data and returned.

    Args:
        *models: Candidate models passed to ``Benchmark``.
        defaults: Forwarded unchanged to ``Benchmark`` (semantics defined
            there).
        params: Forwarded unchanged to ``Benchmark`` (semantics defined there).
        min_train_error: Optional threshold. When set, only models whose
            maximum training error is strictly below it are considered. If
            none qualify, a warning is printed and the model with the lowest
            maximum training error is chosen instead.
        metric: Name of the error metric. Validation values are read from
            ``entry[metric]`` and training values from
            ``entry["trn_" + metric]``.
    """

    def __init__(self,
                 *models,
                 defaults=None,
                 params=None,
                 min_train_error=None,
                 metric="mae") -> None:

        super().__init__()
        self.min_train_error = min_train_error
        self.metric = metric
        self.models = models
        self.defaults = defaults
        self.params = params

    def do(self, X, y, partitions=None, **kwargs):
        """Benchmark all candidates, fit the best one on ``(X, y)`` and return it.

        Args:
            X: Input data. ``len(X)`` is used to size the default split.
            y: Targets matching ``X``.
            partitions: Train/validation partitions for the benchmark. If
                ``None``, a single random split is used, with
                ``int(0.7 * len(X))`` rows for training.
            **kwargs: Forwarded both to ``Benchmark`` and to the final
                ``model.fit`` call.

        Returns:
            The selected model object, after calling ``fit(X, y, **kwargs)``
            on it.

        Raises:
            IndexError: If the benchmark produced no results to choose from.
        """
        # raise_exception=False: presumably lets a failing candidate be skipped
        # rather than aborting the whole selection -- confirm in Benchmark.
        self.benchmark = Benchmark(models=self.models,
                                   params=self.params,
                                   defaults=self.defaults,
                                   kwargs=kwargs,
                                   raise_exception=False)

        if partitions is None:
            # NOTE: a cross-validation partitioning could be used here instead
            # of a single 70/30 random split.
            partitions = RandomPartitioning(X, int(0.7 * len(X)), n_sets=1).do()

        self.benchmark.do(X, y, partitions=partitions)
        results = self.benchmark.group_by_model()

        model = self._select(results)
        # Refit the winner on all of the data, not just one partition.
        model.fit(X, y, **kwargs)

        return model

    def _select(self, results):
        """Return the best model key from ``results`` (``{model: [entry, ...]}``).

        Each entry is a mapping holding ``"trn_" + metric`` and ``metric``.
        Raises ``IndexError`` when ``results`` is empty.
        """
        # (model, worst training error across partitions, mean validation error)
        scores = [
            (obj,
             np.array([e["trn_" + self.metric] for e in entries]).max(),
             np.array([e[self.metric] for e in entries]).mean())
            for obj, entries in results.items()
        ]

        if self.min_train_error is not None:
            eligible = [s for s in scores if s[1] < self.min_train_error]

            if not eligible:
                print("WARNING: No model with minimum training error of %s was found." % self.min_train_error)
                # Fall back to the full candidate list (not the empty filtered
                # one) and pick the model that fits the training data best.
                return sorted(scores, key=lambda s: s[1])[0][0]

            scores = eligible

        # Stable sort: on ties, the first model in result order wins.
        return sorted(scores, key=lambda s: s[2])[0][0]
