import numpy as np
import scipy
from scipy.optimize import golden

from pymoo.algorithms.so_pso import PSO
from pymoo.model.problem import Problem
from pymoo.optimize import minimize
from pymoo.util.misc import vectorized_cdist
from surrogate.custom.kernel import Gaussian, squared_dist
from surrogate.custom.lsq import LSQQR, LSQ
from surrogate.model import Model


def calc_center_sigma(X, centers, n_neighbors=10):
    """Return, per row of ``X``, the mean of its smallest distances to ``centers``.

    Distances come from ``squared_dist(X, centers)``. The diagonal of that
    matrix is masked with ``inf`` so a point paired with "itself" is never
    counted. Each row is then sorted ascending and its last (largest) column
    dropped. The mean is taken over the first ``n_neighbors`` remaining
    columns, or over all of them when ``n_neighbors`` is ``None``.

    NOTE(review): dropping the last column removes the masked self-distance
    only when every row contains exactly one diagonal entry, i.e. when the
    distance matrix is square. Confirm for callers passing distinct centers.
    """
    dist = squared_dist(X, centers)
    np.fill_diagonal(dist, np.inf)

    nearest = np.sort(dist, axis=1)[:, :-1]
    if n_neighbors is not None:
        nearest = nearest[:, :n_neighbors]

    return nearest.mean(axis=1)


class RBF(Model):
    """Radial-basis-function surrogate fitted by least squares on a kernel matrix.

    The design matrix is ``K = kernel.calc(X, Y=centers)``, optionally with
    ``alpha`` added to its diagonal. Weights are obtained with the project's
    ``LSQ`` solver, or with ``LSQQR`` when LOO / GCV / MLE diagnostics are
    requested.

    Parameters
    ----------
    alpha : float or None
        Value added to the diagonal of ``K`` before solving. ``None`` disables it.
    kernel : object or None
        Kernel exposing ``calc(X, Y=...)``. ``None`` (the default) means
        ``Gaussian(1.0)``, created per instance so models never share one
        default kernel object.
    optimize : bool
        If true, ``_optimize`` tunes the Gaussian ``sigma`` by minimising the
        brute-force leave-one-out error and refits.
    """

    def __init__(self,
                 alpha=1e-16,
                 kernel=None,
                 optimize=False,
                 **kwargs):

        super().__init__(**kwargs)
        self.weights = None
        self.alpha = alpha
        # Fresh default kernel per instance (avoids a shared default-argument object).
        self.kernel = Gaussian(1.0) if kernel is None else kernel
        self.centers = None
        self.optimize = optimize

    def _fit(self,
             X,
             y,
             centers=None,
             calc_error_loo=False,
             calc_error_gcv=False,
             calc_mle=False,
             **kwargs):
        """Fit the RBF weights and optionally compute error diagnostics.

        Parameters
        ----------
        X, y
            Training inputs and targets.
        centers
            Kernel centers. When ``None``, ``X`` itself is used.
        calc_error_loo : bool
            Store closed-form leave-one-out residuals in ``v_e_loo`` and their
            mean square in ``e_loo``.
        calc_error_gcv : bool
            Store GCV-style residuals in ``v_e_gcv`` and their mean square in
            ``e_gcv``.
        calc_mle : bool
            Store ``y.T @ beta + mean(log(eigenvalues of K))`` in ``mle``.

        Any diagnostic switches the solver to ``LSQQR(calc_A_inv=True)``.
        Otherwise the plain ``LSQ`` solver is used.
        """
        self.centers = X if centers is None else centers

        K = self.kernel.calc(X, Y=self.centers)
        if self.alpha is not None:
            # Not in-place, so the array returned by the kernel is never mutated.
            # NOTE(review): np.eye(len(K)) only matches K's shape when K is
            # square (as many centers as samples); distinct centers will fail here.
            K = K + np.eye(len(K)) * self.alpha

        if calc_error_loo or calc_error_gcv or calc_mle:

            self.lsq = LSQQR(calc_A_inv=True)
            self.lsq.fit(K, y)

            # NOTE(review): presumably this product equals K^{-1}; depends on LSQQR internals.
            K_inv = self.lsq.A_inv @ self.lsq.rhs
            beta = self.lsq.beta[:, 0]

            if calc_error_loo:
                # Residual per sample: coefficient divided by the matching diagonal entry.
                self.v_e_loo = -  beta / np.diag(K_inv)
                self.e_loo = (self.v_e_loo ** 2).mean()

            if calc_error_gcv:
                # Same as LOO, but with the diagonal replaced by its average.
                self.v_e_gcv = - beta / np.diag(K_inv).mean()
                self.e_gcv = (self.v_e_gcv ** 2).mean()

            if calc_mle:
                eig, _ = np.linalg.eig(K)
                self.mle = (y.T @ beta) + np.log(eig).mean()

        else:
            self.lsq = LSQ()
            self.lsq.fit(K, y)

    def _predict(self, X, out, **kwargs):
        """Evaluate the kernel against the stored centers and apply the fitted solver."""
        H = self.kernel.calc(X, Y=self.centers)
        out["y"] = self.lsq.predict(H)

    def _optimize(self, **kwargs):
        """Tune the Gaussian kernel width by minimising the brute-force LOO error.

        Does nothing unless ``self.optimize`` is true. Otherwise it does three things:
        runs a golden-section search over ``sigma``, installs
        ``Gaussian(sigma=...)`` as the kernel, and refits on ``self.X`` /
        ``self.y``. Extra keyword arguments are forwarded to ``_fit``.
        """
        if not self.optimize:
            return

        X, y = self.X, self.y

        def fun_loo(sigma):
            # Explicit leave-one-out: RBFLOO refits an RBF once per sample.
            rbf = RBFLOO(alpha=self.alpha, kernel=Gaussian(sigma=sigma))
            rbf.fit(X, y)
            return rbf.e_loo

        sigma_L, sigma_U = 1e-4, 10

        # A 2-tuple `brack` only seeds scipy's downhill bracket search, so the
        # returned sigma is not guaranteed to lie within [sigma_L, sigma_U].
        sigma = golden(fun_loo, brack=(sigma_L, sigma_U))

        self.kernel = Gaussian(sigma=sigma)
        self._fit(X, y, **kwargs)





class RBFEquation(Model):
    """RBF interpolant solved by explicitly inverting the kernel matrix.

    ``_fit`` stores the coefficients ``c``. It also stores closed-form LOO
    and GCV residuals and errors, plus an MLE-style score, all derived from
    the explicit inverse.

    Parameters
    ----------
    alpha : float or None
        Value added to the diagonal of ``K`` before inversion. ``None`` disables it.
    kernel : object or None
        Kernel exposing ``calc(X)`` / ``calc(X, Y=...)``. ``None`` (the default)
        means ``Gaussian(1.0)``, created per instance so models never share
        one default kernel object.
    """

    def __init__(self,
                 alpha=1e-16,
                 kernel=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.weights = None
        self.alpha = alpha
        # Fresh default kernel per instance (avoids a shared default-argument object).
        self.kernel = Gaussian(1.0) if kernel is None else kernel

    def _fit(self, X, y, **kwargs):
        """Compute ``c = K^{-1} y`` and the LOO / GCV / MLE diagnostics.

        ``y`` is expected to be 2-D: the diagnostics use column 0 of ``c``.
        """
        K = self.kernel.calc(X)

        if self.alpha is not None:
            K = K + np.eye(len(K)) * self.alpha

        K_inv = np.linalg.inv(K)
        self.c = K_inv @ y

        # LOO residual per sample: coefficient divided by the matching diagonal of K^{-1}.
        self.v_e_loo = - (self.c[:, 0]) / np.diag(K_inv)
        self.e_loo = (self.v_e_loo ** 2).mean()

        # GCV variant: the diagonal is replaced by its average.
        self.v_e_gcv = - (self.c[:, 0]) / np.diag(K_inv).mean()
        self.e_gcv = (self.v_e_gcv ** 2).mean()

        self.mle = (y * self.c).sum() + np.log(np.linalg.eig(K)[0]).mean()

    def _predict(self, X, out, **kwargs):
        """Predict as ``kernel.calc(X, Y=self.X) @ c``, using the training inputs as centers."""
        out["y"] = self.kernel.calc(X, Y=self.X) @ self.c


class RBFLOO(Model):
    """Brute-force leave-one-out evaluator for an ``RBF`` configuration.

    ``_fit`` refits an ``RBF`` once per sample, with that sample held out,
    and records the prediction residual at the held-out point. The class
    does not produce predictions of its own.

    Parameters
    ----------
    alpha : float or None
        Forwarded to every inner ``RBF``.
    kernel : object or None
        Forwarded to every inner ``RBF``. ``None`` (the default) means
        ``Gaussian(1.0)``, created per instance so evaluators never share one
        default kernel object.
    """

    def __init__(self,
                 alpha=1e-16,
                 kernel=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.weights = None
        self.alpha = alpha
        # Fresh default kernel per instance (avoids a shared default-argument object).
        self.kernel = Gaussian(1.0) if kernel is None else kernel

    def _fit(self, X, y, **kwargs):
        """Store the LOO residuals in ``v_e_loo`` and their mean square in ``e_loo``.

        Performs ``len(X)`` RBF fits, each on ``len(X) - 1`` points.
        """
        self.v_e_loo = np.zeros(len(y))

        for k in range(len(X)):
            # Boolean mask selecting every sample except the k-th.
            b = np.full(len(X), True)
            b[k] = False

            rbf = RBF(alpha=self.alpha, kernel=self.kernel)
            rbf.fit(X[b], y[b])

            y_hat = rbf.predict(X[~b])[0]

            self.v_e_loo[k] = (y_hat - y[~b])

        self.e_loo = (self.v_e_loo ** 2).mean()

    def _predict(self, X, out, **kwargs):
        # No-op: `out` is left untouched.
        # NOTE(review): this class only evaluates LOO error; predicting is unsupported.
        pass
