import numpy as np

from pycheapconstr.algorithms.sansga2 import SANSGA2, my_select_points_with_maximum_distance
from pycheapconstr.misc.surrogate import Surrogate
from pycheapconstr.sampling.energy_sampling import EnergySampling
from pymoo.algorithms.so_de import DifferentialEvolutionMating
from pymoo.docs import parse_doc_string
from pymoo.model.population import Population
from surrogate.models.rbf import RBF
from surrogate.selection import ModelSelection


class ExactConstraintsSurrogate(Surrogate):
    """Surrogate that models the objectives but takes constraints from the problem.

    Objective ``k`` is predicted by the fitted model ``self.surr_F[k]``. When
    the problem has constraints, ``G`` is not approximated. It is obtained by
    calling ``self.problem.evaluate`` with ``only_inexpensive_constraints=True``.
    """

    def _evaluate(self, X, out, *args, **kwargs):
        # One prediction per objective, stacked side by side as columns of F.
        predictions = [self.surr_F[k].predict(X) for k in range(self.n_obj)]
        out["F"] = np.column_stack(predictions)

        if self.n_constr <= 0:
            return

        # Constraint values come straight from the underlying problem.
        out["G"] = self.problem.evaluate(X, return_values_of=["G"],
                                         only_inexpensive_constraints=True)


class ICSANSGA2(SANSGA2):
    """SANSGA2 variant that models only the objectives and evaluates constraints directly.

    :meth:`fit` builds an :class:`ExactConstraintsSurrogate`. Its objectives
    are RBF models and its constraints are queried from the problem with
    ``only_inexpensive_constraints=True``. :meth:`step_sampling` screens
    generated candidates against those constraints before the final
    distance-based selection.
    """

    # Number of offspring generated by the mating operator in
    # ``step_sampling`` (when ``cand`` is given). They are then filtered for
    # feasibility and reduced to ``n_sampling`` points.
    n_sampling_candidates = 200

    # Probability of generating those offspring with ``self.mating``. The
    # remaining probability uses ``DifferentialEvolutionMating("random")``.
    # The default of 1.0 always picks ``self.mating``, exactly like the
    # original hard-coded check. A random number is still drawn so the
    # random stream is unchanged.
    prob_default_mating = 1.0

    def _initialize(self):
        """Create and evaluate the initial population.

        Without constraints the parent initialization is used unchanged.
        Otherwise:

        - the initial design is drawn with ``EnergySampling``, guided by the
          problem's constraint violation (``CV``);
        - it is evaluated;
        - it is passed through ``self.survival``, with
          ``self.min_infeas_pop_size`` given as ``n_min_infeas_survive``.

        The result becomes both ``self.pop`` and ``self.off``.
        """
        if self.problem.n_constr == 0:
            super()._initialize()
        else:
            def constr(X):
                # Column 0 of the (n, 1) CV result -> one violation per point.
                return self.problem.evaluate(X, return_values_of=["CV"])[:, 0]

            pop = EnergySampling(constr).do(self.problem, self.pop_size)
            self.evaluator.eval(self.problem, pop, algorithm=self)

            pop = self.survival.do(self.problem, pop, len(pop), algorithm=self,
                                   n_min_infeas_survive=self.min_infeas_pop_size)

            self.pop, self.off = pop, pop

    def fit(self):
        """Fit one RBF-based model per objective on the current population.

        Returns
        -------
        ExactConstraintsSurrogate
            Surrogate initialised on ``self.problem``. ``surr_F[k]`` is the
            model chosen by ``ModelSelection(RBF)`` for objective ``k`` and
            then fitted on ``(X, F[:, k])``. Constraints are not modelled.
        """
        X, F = self.pop.get("X", "F")

        surr = ExactConstraintsSurrogate(model=None).initialize(self.problem)
        surr.surr_F = []
        for k in range(self.problem.n_obj):
            model = ModelSelection(RBF).do(X, F[:, k])
            model.fit(X, F[:, k])
            surr.surr_F.append(model)

        return surr

    def step_sampling(self, n_sampling, cand):
        """Generate up to ``n_sampling`` exploration points.

        Parameters
        ----------
        n_sampling : int
            Number of points requested.
        cand : Population or None
            Presumably the candidates already chosen in this iteration; verify
            against the caller.

            - If ``None``, ``n_sampling`` offspring are produced directly by
              ``self.mating`` from ``self.opt``.
            - Otherwise, ``n_sampling_candidates`` offspring are generated.
              Those not feasible w.r.t. the inexpensive constraints are
              dropped. From the rest, ``n_sampling`` are chosen by
              ``my_select_points_with_maximum_distance`` relative to
              ``self.pop`` merged with ``cand``.

        Returns
        -------
        Population
            The selected offspring, tagged with ``strategy="sampling"``. If no
            candidate passes the feasibility check, an empty, untagged
            ``Population`` is returned instead.
        """
        if cand is None:
            sampling_off = self.mating.do(self.problem, self.opt, n_sampling, algorithm=self)

        else:
            # Candidate pre-screening (added after the paper was submitted).
            # The random draw always happens, keeping the RNG stream stable.
            if np.random.random() < self.prob_default_mating:
                mating = self.mating
            else:
                mating = DifferentialEvolutionMating("random")
            sampling_off = mating.do(self.problem, self.opt, self.n_sampling_candidates, algorithm=self)

            others = Population.merge(self.pop, cand)

            # Discard candidates that already violate the cheap constraints.
            feas = self.problem.evaluate(sampling_off.get("X"), return_values_of=["feasible"],
                                         only_inexpensive_constraints=True)[:, 0]
            sampling_off = sampling_off[feas]

            if len(sampling_off) == 0:
                return Population()

            selected = my_select_points_with_maximum_distance(self.problem, sampling_off.get("X"),
                                                              others.get("X"), n_sampling)
            sampling_off = sampling_off[selected]

        sampling_off.set("strategy", "sampling")

        return sampling_off


# Module-level call, executed when this file is imported.
# NOTE(review): presumably this expands pymoo's shared docstring placeholders
# in the constructor's docstring, following pymoo's own
# ``parse_doc_string(<Algorithm>.__init__)`` pattern. Confirm, and confirm that
# repeating it here is intended if sansga2 already does the same.
# ICSANSGA2 defines no ``__init__`` of its own, so its constructor is the one
# inherited through SANSGA2.
parse_doc_string(SANSGA2.__init__)
