from pySOT.surrogate import RBFInterpolant, CubicKernel, LinearTail, TPSKernel, ConstantTail

from surrogate.model import Model


class RBF(Model):
    """Radial basis function surrogate backed by pySOT's ``RBFInterpolant``.

    The kernel and polynomial tail are kept as names and only turned into
    pySOT objects (via ``get_kernel`` / ``get_tail``) when the model is fitted.
    """

    def __init__(self, kernel="cubic", tail="linear", **kwargs) -> None:
        # Any extra keyword arguments are handed to ``Model`` unchanged.
        super().__init__(**kwargs)
        self.kernel = kernel
        self.tail = tail

    def _fit(self, X, y, **kwargs):
        """Build a fresh interpolant sized to X's column count and load (X, y).

        A new ``RBFInterpolant`` is created on every call, so any previously
        fitted state on ``self.model`` is replaced.
        """
        # Unpacking (instead of indexing) keeps the implicit requirement that
        # X has exactly two dimensions; the row count itself is not needed.
        _, dim = X.shape
        self.model = RBFInterpolant(
            dim=dim,
            kernel=get_kernel(self.kernel),
            tail=get_tail(self.tail, dim),
        )
        self.model.add_points(X, y)

    def _predict(self, X, out):
        """Write the fitted interpolant's predictions for X into ``out["y"]``."""
        prediction = self.model.predict(X)
        out["y"] = prediction

    @classmethod
    def hyperparameters(cls):
        """Return the candidate values for each constructor hyperparameter."""
        # NOTE(review): get_tail also understands "constant", but it is not
        # offered here -- presumably deliberate; confirm before adding it.
        kernels = ["cubic", "tps"]
        tails = ["linear"]
        return {"kernel": kernels, "tail": tails}


def get_kernel(kernel):
    """Map a kernel name to a new pySOT kernel instance.

    Args:
        kernel: Either ``"cubic"`` (-> ``CubicKernel``) or ``"tps"``
            (-> ``TPSKernel``).

    Returns:
        A freshly constructed kernel object.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
            Previously an unknown name silently produced ``None``, which was
            then passed on to ``RBFInterpolant``.
    """
    if kernel == "cubic":
        return CubicKernel()
    elif kernel == "tps":
        return TPSKernel()
    raise ValueError(f"Unknown RBF kernel {kernel!r}; expected 'cubic' or 'tps'")


def get_tail(tail, m):
    """Map a tail name to a new pySOT polynomial tail of dimension ``m``.

    Args:
        tail: Either ``"linear"`` (-> ``LinearTail``) or ``"constant"``
            (-> ``ConstantTail``).
        m: Dimension passed to the tail constructor (the caller in this file
            passes the number of input columns).

    Returns:
        A freshly constructed tail object.

    Raises:
        ValueError: If ``tail`` is not one of the supported names.
            Previously an unknown name silently produced ``None``.
    """
    if tail == "linear":
        return LinearTail(m)
    elif tail == "constant":
        return ConstantTail(m)
    raise ValueError(f"Unknown RBF tail {tail!r}; expected 'linear' or 'constant'")
