from copy import deepcopy

import numpy as np

from pymoo.model.problem import Problem


class Surrogate(Problem):
    """Problem whose evaluations are answered by fitted surrogate models.

    After :meth:`initialize` the instance mirrors the wrapped problem's
    ``n_var``, ``n_obj``, ``xl`` and ``xu``. It holds one deep copy of
    ``model`` per objective (``surr_F``) and one per modelled constraint
    (``surr_G``). The models are only ever called as ``fit(X, y, **kwargs)``,
    ``predict(X)`` and ``optimize(X, y, **kwargs)``.

    With ``CV=True`` and a constrained problem, the individual constraints are
    not modelled. A single model is trained on the population's ``CV`` values,
    and the surrogate reports exactly one constraint (``n_constr == 1``).
    """

    def __init__(self, model=None, CV=False, callback=None):
        """
        Parameters
        ----------
        model : object, optional
            Prototype model, deep-copied once per modelled target in
            :meth:`initialize`. If ``None``, no models are created.
        CV : bool
            Model the aggregated constraint violation instead of every
            constraint separately.
        callback : object, optional
            Stored as ``self.my_callback``; no method of this class reads it.
        """
        super().__init__()
        self.model = model
        self.problem = None
        self.surr_F = None
        self.surr_G = None

        # Dimensions are unknown until ``initialize`` sees the real problem.
        self.n_obj = None
        self.n_constr = None

        self.CV = CV

        self.my_callback = callback

    def initialize(self, problem):
        """Adopt the shape of ``problem`` and create the per-target models.

        Returns ``self`` so the call can be chained.
        """
        # In CV mode all constraints collapse into a single target, so the
        # surrogate exposes (and models) only one constraint.
        if self.CV and problem.n_constr > 0:
            n_constr = 1
        else:
            n_constr = problem.n_constr

        super().__init__(problem.n_var, problem.n_obj, n_constr, problem.xl, problem.xu)
        self.problem = problem

        self.surr_F = []
        self.surr_G = []

        if self.model is not None:
            # Independent copies so that fitting one target never touches another.
            self.surr_F = [deepcopy(self.model) for _ in range(self.n_obj)]
            self.surr_G = [deepcopy(self.model) for _ in range(self.n_constr)]

        return self

    def _evaluate(self, X, out, *args, **kwargs):
        """Write the models' predictions for ``X`` into ``out``.

        ``out["F"]`` gets one column per objective. ``out["G"]`` gets one
        column per modelled constraint and is written only when the
        surrogate has constraints. If the wrapped problem defines ``_post``,
        it is then called with ``X`` and the filled ``out`` dict.
        """
        out["F"] = np.column_stack([self.surr_F[k].predict(X) for k in range(self.n_obj)])

        if self.n_constr > 0:
            out["G"] = np.column_stack([self.surr_G[k].predict(X) for k in range(self.n_constr)])

        if hasattr(self.problem, '_post'):
            self.problem._post(X, out)

    def _training_data(self, pop):
        """Return ``(X, F, G)`` from ``pop``, using ``CV`` as ``G`` in CV mode."""
        X, F, G, CV = pop.get("X", "F", "G", "CV")

        if self.CV:
            G = CV

        return X, F, G

    def fit(self, pop, **kwargs):
        """Fit every model on ``pop``; ``kwargs`` are forwarded to ``fit``."""
        X, F, G = self._training_data(pop)

        for k in range(self.n_obj):
            self.surr_F[k].fit(X, F[:, k], **kwargs)

        # ``self.n_constr`` (not ``self.problem.n_constr``) so that CV mode
        # fits exactly the single model created in ``initialize``.
        for k in range(self.n_constr):
            self.surr_G[k].fit(X, G[:, k], **kwargs)

    def optimize(self, pop, **kwargs):
        """Call ``optimize`` on every model with its target column from ``pop``.

        What ``optimize`` does is up to the model class (presumably
        hyper-parameter tuning; confirm against the model implementations).
        ``kwargs`` are forwarded unchanged.
        """
        X, F, G = self._training_data(pop)

        for k in range(self.n_obj):
            self.surr_F[k].optimize(X, F[:, k], **kwargs)

        for k in range(self.n_constr):
            self.surr_G[k].optimize(X, G[:, k], **kwargs)

