import numpy as np

from surrogate.model import Model


def calc_kernel_matrix(A, B, func, theta):
    """Build the kernel matrix between the rows of ``A`` and the rows of ``B``.

    Parameters
    ----------
    A : ndarray, shape (n_a, d)
    B : ndarray, shape (n_b, d)
    func : callable
        Called once as ``func(diffs, theta)`` where ``diffs`` has shape
        ``(n_a * n_b, d)`` and row ``i * n_b + j`` holds ``A[i] - B[j]``.
        Must return ``n_a * n_b`` values, one per row of ``diffs``.
    theta
        Kernel hyper-parameter forwarded verbatim to ``func``.

    Returns
    -------
    ndarray, shape (n_a, n_b)
        Entry ``[i, j]`` is the kernel value for the pair ``(A[i], B[j])``.
    """
    n_rows = A.shape[0]
    n_cols = B.shape[0]

    # Broadcast to (n_a, n_b, d) and flatten the first two axes row-major, so
    # the pair (i, j) lands at flat index i * n_b + j.
    pairwise = A[:, np.newaxis, :] - B[np.newaxis, :, :]
    flat = pairwise.reshape(n_rows * n_cols, pairwise.shape[-1])

    values = func(flat, theta)
    return np.reshape(values, (n_rows, n_cols))


def gaussian(D, sigma):
    """Evaluate a Gaussian kernel on each row of ``D``.

    Computes ``exp(-||D[k]||**2 / (2 * sigma))`` for every row ``k``.

    Note that ``sigma`` enters the denominator directly (it is not squared),
    so it behaves like a variance rather than a standard deviation.

    Parameters
    ----------
    D : ndarray, shape (m, d)
        Difference vectors, one per row.
    sigma : float
        Kernel width parameter (see note above).

    Returns
    -------
    ndarray, shape (m,)
    """
    row_norms = np.linalg.norm(D, axis=1)
    denominator = 2 * sigma
    return np.exp(-(row_norms ** 2) / denominator)


class RBFLocal(Model):
    """Gaussian RBF surrogate fitted around one (or a few) target rows.

    ``_fit`` splits the rows of ``X`` into the target rows (``target``) and the
    remaining rows, which are used as RBF centres. The weights are the
    least-squares (pseudo-inverse) solution of
    ``K(X[target], X[centres]) @ w ~= y[target]``. ``_predict`` evaluates the
    same kernel against the stored centres.
    """

    def __init__(self, *, sigma=1, **kwargs):
        """Create an unfitted model.

        Parameters
        ----------
        sigma : float, optional
            Width parameter passed to :func:`gaussian` in both fitting and
            prediction. Defaults to 1.
        **kwargs
            Forwarded unchanged to ``Model.__init__``.
        """
        super().__init__(**kwargs)
        self.sigma = sigma
        self.weights = None
        # Rows of X used as RBF centres in the last _fit. _predict must use
        # exactly these rows so the kernel matrix lines up with self.weights.
        self.centers = None

    def _fit(self, X, y, target=None):
        """Fit the RBF weights for the given target row(s).

        Parameters
        ----------
        X : ndarray, shape (n, d)
            Input points; indexed by row.
        y : ndarray
            Responses, indexed by the same row indices as ``X``.
        target : int or sequence of int
            Row index (or indices) of ``X`` treated as targets. All other rows
            become the RBF centres.

        Raises
        ------
        ValueError
            If ``target`` is not given.
        """
        if target is None:
            raise ValueError("This is an implementation of an RBF to predict a single point. "
                             "Please provide the target index.")

        # Accept a scalar index as well as a list/tuple/array of indices.
        vld = np.atleast_1d(target).tolist()
        # Compare against the normalised index set (works for scalar targets
        # and gives O(1) membership tests).
        excluded = set(vld)
        trn = [k for k in range(len(X)) if k not in excluded]

        centers = X[trn]
        H = calc_kernel_matrix(X[vld], centers, gaussian, self.sigma)

        # NOTE(review): the weights are fitted to the *target* responses
        # y[vld], with the non-target rows as centres. Confirm this is the
        # intended design rather than fitting on y[trn].
        # pinv(H) equals pinv(H.T @ H) @ H.T mathematically, but avoids
        # squaring the condition number of the (often ill-conditioned)
        # Gaussian kernel matrix.
        self.weights = np.linalg.pinv(H) @ y[vld]
        self.centers = centers

    def _predict(self, X):
        """Evaluate the fitted RBF at the rows of ``X``.

        Returns ``K(X, centres) @ weights``, using the same centres and
        ``sigma`` that were used in ``_fit``.
        """
        # Fall back to self.X only if _fit never recorded its centres.
        centers = self.centers if self.centers is not None else self.X
        H = calc_kernel_matrix(X, centers, gaussian, self.sigma)
        return H @ self.weights
