import copy

import numpy as np
from sklearn.cluster import KMeans

from pycheapconstr.misc.surrogate import Surrogate
from pymoo.algorithms.nsga2 import NSGA2
from pymoo.algorithms.so_de import DifferentialEvolutionMating
from pymoo.docs import parse_doc_string
from pymoo.model.duplicate import DefaultDuplicateElimination
from pymoo.model.population import Population
from pymoo.operators.sampling.latin_hypercube_sampling import LHS
from pymoo.util.misc import norm_eucl_dist
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
from pymoo.util.normalization import normalize
from pymoo.util.roulette import RouletteWheelSelection
from pymoo.visualization.scatter import Scatter
from surrogate.models.rbf import RBF
from surrogate.util.normalization import ZeroToOneNormalization


def my_select_points_with_maximum_distance(problem, X, others, n_select, selected=None):
    """Greedily pick rows of ``X`` that are spread out from each other and from ``others``.

    Max-min (farthest-point) selection: every step adds the not-yet-selected
    row of ``X`` whose nearest neighbour among ``others`` *and* the rows
    selected so far is farthest away. Distances come from pymoo's
    ``norm_eucl_dist``.

    Parameters
    ----------
    problem
        Problem object passed through to ``norm_eucl_dist``.
    X
        Candidate points, one per row.
    others
        Points the selection should keep away from, one per row.
    n_select
        Total number of indices to return, including any pre-``selected`` ones.
    selected
        Optional indices into ``X`` that are already chosen. The sequence is
        copied and never modified. When empty or None, the first pick is the
        candidate farthest from ``others``.

    Returns
    -------
    list of int
        Indices into ``X``. Fewer than ``n_select`` are returned only when
        ``X`` has no unselected rows left.
    """
    n_points = len(X)

    # Pairwise candidate distances, and candidate -> others distances.
    D = norm_eucl_dist(problem, X, X)
    dist_to_others = norm_eucl_dist(problem, X, others)

    # Distance from every candidate to its nearest "other" (+inf if there are none).
    if dist_to_others.shape[1] > 0:
        min_dist = dist_to_others.min(axis=1)
    else:
        min_dist = np.full(n_points, np.inf)

    # Copy so neither the caller's list nor a shared default is mutated.
    selected = [] if selected is None else [int(i) for i in selected]

    if len(selected) == 0:
        selected = [int(min_dist.argmax())]

    # Running distance of every candidate to its nearest taken point (others + selected).
    min_dist = np.minimum(min_dist, D[:, selected].min(axis=1))

    available = np.ones(n_points, dtype=bool)
    available[selected] = False

    while len(selected) < n_select and available.any():
        candidates = np.flatnonzero(available)
        # argmax returns the first maximum, i.e. the lowest index on ties.
        best = int(candidates[min_dist[candidates].argmax()])

        selected.append(best)
        available[best] = False

        # A new selection can only bring the remaining points closer to the taken set.
        min_dist = np.minimum(min_dist, D[:, best])

    return selected


class SANSGA2(NSGA2):
    """Surrogate-assisted NSGA-II.

    Each generation evaluates ``n_offsprings`` new points on the real problem:

    * ``int(n_offsprings * surr_exploit)`` of them ("surrogate" strategy) are
      chosen from a short NSGA-II run on an RBF surrogate fitted to the
      current population (see ``step_surr``);
    * the rest ("sampling" strategy) come from the mating operator and, when
      surrogate candidates exist, are filtered with
      ``my_select_points_with_maximum_distance`` against the population and
      those candidates (see ``step_sampling``).

    Every individual evaluated on the real problem is appended to
    ``self.archive``.
    """

    def __init__(self,
                 n_offsprings=5,
                 sampling=LHS(),
                 surr_exploit=0.7,
                 inexp_constraints=False,
                 **kwargs):
        # n_offsprings: real (expensive) evaluations per generation.
        # sampling: initial-design operator (Latin hypercube by default).
        # surr_exploit: fraction of n_offsprings taken from the surrogate search
        #   (truncated with int()); the remainder comes from step_sampling.
        # inexp_constraints: stored only; not read anywhere in this class.
        super().__init__(n_offsprings=n_offsprings,
                         sampling=sampling,
                         **kwargs)

        # When True, _next() runs plain NSGA-II. This is set on the deep copy
        # that optimizes the surrogate inside step_surr().
        self.default = False
        # Every individual evaluated on the real problem.
        self.archive = Population()
        self.inexp_constraints = inexp_constraints
        self.surr_exploit = surr_exploit

    def setup(self, problem, **kwargs):
        super().setup(problem, **kwargs)
        # Population size tied to the dimension: 11 * n_var - 1 (presumably the
        # usual 11d-1 initial-design rule for expensive optimization).
        self.pop_size = 11 * problem.n_var - 1

    def _initialize(self):
        super()._initialize()
        # Assumes the base class leaves the evaluated initial design in self.off.
        self.archive = Population.merge(self.archive, self.off)
        self.pop.set("strategy", "init")

    def _next(self):
        """Run one generation: surrogate exploitation plus exploratory sampling."""

        # The surrogate-side copy made in step_surr() runs plain NSGA-II.
        if self.default:
            super()._next()
            return

        n_surrogate = int(self.n_offsprings * self.surr_exploit)
        n_sampling = self.n_offsprings - n_surrogate

        off = Population()

        # Produced only by step_surr(); also consumed by step_sampling and the plot.
        cand, groups, S = None, None, None

        if n_surrogate > 0:
            surr_off, cand, groups, S = self.step_surr(n_surrogate)
            off = Population.merge(off, surr_off)

        if n_sampling > 0:
            sampling_off = self.step_sampling(n_sampling, cand)
            off = Population.merge(off, sampling_off)

        # Evaluate on the real (expensive) problem and archive the results.
        self.off = off
        self.evaluator.eval(self.problem, self.off, algorithm=self)
        self.archive = Population.merge(self.archive, self.off)

        # NOTE(review): survival keeps 10 * n_offsprings individuals rather than
        # the pop_size computed in setup() -- confirm this is intended.
        merged = Population.merge(self.pop, self.off)
        self.pop = self.survival.do(self.problem, merged, int(self.n_offsprings * 10),
                                    algorithm=self)

        visualize = False
        if visualize:
            self._visualize_generation(cand, groups, S)

    def _visualize_generation(self, cand, groups, S):
        """Scatter-plot the objective space of the current generation (debug aid)."""
        strat = self.off.get("strategy")
        exploit = np.where(strat == "surrogate")[0]
        explore = np.where(strat == "sampling")[0]

        sc = Scatter(legend=True)
        sc.add(self.pop.get("F"), label="pop", alpha=0.5)

        if cand is not None:
            # groups is None when step_surr() kept every candidate without clustering.
            if groups is not None:
                for group in groups:
                    sc.add(cand[group].get("F"), label="cand", alpha=0.2)

            sc.add(cand.get("F")[S], color='purple', marker="x", s=70)

        if len(exploit) > 0:
            sc.add(self.off[exploit].get("F"), label="exploit", color="red")
        if len(explore) > 0:
            sc.add(self.off[explore].get("F"), label="explore", color="green")

        for ind in Population.merge(self.pop, self.off):
            if not ind.get("feasible"):
                sc.add(ind.F, marker="x", color="black", alpha=0.3)

        sc.show()

    def step_sampling(self, n_sampling, cand):
        """Create ``n_sampling`` exploratory offspring tagged ``strategy="sampling"``.

        Without surrogate candidates (``cand is None``), the mating operator
        produces them directly. Otherwise 200 mating offspring are generated,
        and the ``n_sampling`` chosen by ``my_select_points_with_maximum_distance``
        against the current population plus ``cand`` are kept.
        """
        if cand is None:
            sampling_off = self.mating.do(self.problem, self.opt, n_sampling, algorithm=self)

        else:
            n_points = 200

            # NOTE: random() < 1.0 is always true, so the DE branch is unreachable.
            # The draw is kept so the RNG stream (and therefore results) is unchanged.
            if np.random.random() < 1.0:
                mating = self.mating
            else:
                mating = DifferentialEvolutionMating("random")
            sampling_off = mating.do(self.problem, self.opt, n_points, algorithm=self)

            others = Population.merge(self.pop, cand)
            I = my_select_points_with_maximum_distance(self.problem, sampling_off.get("X"), others.get("X"),
                                                       n_sampling)
            sampling_off = sampling_off[I]

        sampling_off.set("strategy", "sampling")

        return sampling_off

    def step_surr(self, n_surrogate):
        """Propose up to ``n_surrogate`` offspring by optimizing a surrogate model.

        A deep copy of this algorithm runs 20 generations of plain NSGA-II
        (100 offspring each) on the surrogate. Its optimum, minus points
        already in the population, forms the candidate set. When there are
        more candidates than needed, they are split into ``n_surrogate``
        k-means clusters in normalized objective space, and one point per
        non-empty cluster is drawn by roulette wheel.

        Returns
        -------
        off : Population
            Proposed points with X, predicted F_hat / CV_hat and strategy="surrogate".
        cand : Population
            All surrogate candidates.
        groups : list of list of int or None
            Candidate indices per cluster; None when no clustering was needed.
        S : sequence of int
            Indices into ``cand`` of the proposed points.
        """
        surr = self.fit()

        n_surr_gen = 20
        n_surr_cand = 100

        # Optimize the surrogate with a throw-away copy of this algorithm.
        cpy = copy.deepcopy(self)
        cpy.survival.nds = NonDominatedSorting(epsilon=0.001)
        cpy.problem = surr
        cpy.n_offsprings = n_surr_cand
        cpy.default = True
        cpy.display = None
        for _ in range(n_surr_gen):
            cpy.next()

        cand = DefaultDuplicateElimination().do(cpy.opt, self.pop)

        if len(cand) <= n_surrogate:
            S = np.arange(len(cand))
            groups = None
        else:
            F_opt = cpy.opt.get("F")
            ideal = F_opt.min(axis=0)
            # Presumably guards normalize() against a zero objective range.
            nadir = F_opt.max(axis=0) + 1e-16
            vals = normalize(cand.get("F"), ideal, nadir)

            kmeans = KMeans(n_clusters=n_surrogate, random_state=0).fit(vals)
            groups = [[] for _ in range(n_surrogate)]
            for k, label in enumerate(kmeans.labels_):
                groups[label].append(k)

            S = []

            for group in groups:
                # k-means can leave a cluster empty (e.g. duplicate F values);
                # roulette selection over an empty group is meaningless.
                if len(group) == 0:
                    continue
                # NOTE(review): argsort() yields an ordering (indices), not ranks;
                # argsort().argsort() may have been intended -- confirm.
                fitness = cand[group].get("crowding").argsort()
                selection = RouletteWheelSelection(fitness, larger_is_better=False)
                S.append(group[selection.next()])

        opt = cand[S]

        return Population.new(X=opt.get("X"), F_hat=opt.get("F"), CV_hat=opt.get("CV"),
                              strategy="surrogate"), cand, groups, S

    def fit(self):
        """Fit an RBF surrogate on the current population and return it.

        Design variables are scaled with ZeroToOneNormalization over the
        problem bounds. Only ``self.pop`` (not ``self.archive``) is used as
        training data.
        """
        model = RBF(norm_X=ZeroToOneNormalization(*self.problem.bounds()))
        surr = Surrogate(model=model).initialize(self.problem)
        surr.fit(self.pop)

        return surr


# Hand SANSGA2.__init__ to pymoo's docstring post-processor at import time
# (presumably fills shared parameter-description placeholders -- confirm
# against pymoo.docs before putting braces into that docstring).
parse_doc_string(SANSGA2.__init__)
