from abc import abstractmethod

import numpy as np


class Normalization:
    """Base class for normalizations with a forward map and its inverse.

    Subclasses override ``forward`` (original space -> normalized space)
    and ``backward`` (normalized space -> original space).

    NOTE(review): this class does not derive from ``abc.ABC`` (no
    ``ABCMeta``), so ``@abstractmethod`` is NOT enforced: the base class can
    be instantiated and both base methods simply return ``None``.
    """

    def __init__(self) -> None:
        # Chain cooperatively so subclasses can call super().__init__().
        super().__init__()

    @abstractmethod
    def forward(self, X):
        """Map ``X`` into the normalized space (to be overridden)."""
        return None

    @abstractmethod
    def backward(self, X):
        """Map normalized ``X`` back to the original space (to be overridden)."""
        return None


class NoNormalization(Normalization):
    """Identity normalization.

    Both directions hand back the very object they receive; no copy is
    made and no state is kept.
    """

    def __init__(self) -> None:
        # Nothing to configure; just chain to the base initializer.
        super().__init__()

    def forward(self, X):
        """Return ``X`` unchanged."""
        return X

    def backward(self, X):
        """Return ``X`` unchanged (the identity is its own inverse)."""
        return X


class ZeroToOneNormalization(Normalization):
    """Min-max normalization mapping ``[xl, xu]`` onto ``[0, 1]``.

    Parameters
    ----------
    xl, xu : array-like or scalar, optional
        Lower and upper bounds. A bound left as ``None`` is estimated on the
        first ``forward`` call as the column-wise min / max of ``X``
        (``axis=0``) when ``estimate_bounds`` is true. It is then stored on
        the instance and reused for later calls.
    estimate_bounds : bool
        Whether missing bounds may be estimated from the data.
    """

    def __init__(self, xl=None, xu=None, estimate_bounds=True) -> None:
        super().__init__()
        self.xl = xl
        self.xu = xu
        self.estimate_bounds = estimate_bounds

    def forward(self, X):
        """Return ``(X - xl) / (xu - xl)``.

        Missing bounds are first estimated from ``X`` (see class docstring).
        A tiny epsilon is added to the range, so a zero-width dimension
        (``xl == xu``) does not divide by zero; there, ``X == xl`` maps to 0.
        """
        if self.estimate_bounds:
            if self.xl is None:
                self.xl = np.min(X, axis=0)
            if self.xu is None:
                self.xu = np.max(X, axis=0)

        xl, xu = self.xl, self.xu

        # Out-of-place on purpose: an in-place ``denom += 1e-32`` raises a
        # numpy casting error when the bounds are integer arrays (e.g. when
        # they were estimated from integer-valued X).
        denom = (xu - xl) + 1e-32

        return (X - xl) / denom

    def backward(self, X):
        """Undo ``forward``: return ``X * (xu - xl) + xl``."""
        return X * (self.xu - self.xl) + self.xl


class Standardization(Normalization):
    """Z-score normalization: ``(X - mu) / sigma``.

    Parameters
    ----------
    mu : array-like or scalar, optional
        Mean to subtract. When ``None`` and ``estimate_bounds`` is true, it is
        estimated on the first ``forward`` call as the column mean of ``X``
        (``axis=0``) and stored for reuse.
    sigma : array-like or scalar, optional
        Standard deviation to divide by. When ``None`` and ``estimate_bounds``
        is true, it is estimated as the column-wise (population, ``ddof=0``)
        standard deviation of ``X`` and stored for reuse.
    estimate_bounds : bool
        Whether missing statistics may be estimated from the data.
    """

    def __init__(self, mu=None, sigma=None, estimate_bounds=True) -> None:
        super().__init__()
        self.mu = mu
        self.sigma = sigma
        self.estimate_bounds = estimate_bounds

    def forward(self, X):
        """Return ``(X - mu) / sigma``, estimating missing statistics first.

        A tiny epsilon is added to ``sigma``, so a constant dimension
        (``sigma == 0``) yields finite values instead of nan/inf.
        """
        if self.estimate_bounds:
            if self.mu is None:
                self.mu = np.mean(X, axis=0)
            if self.sigma is None:
                # Standard deviation, not variance: dividing by the variance
                # does not produce unit spread.
                self.sigma = np.std(X, axis=0)

        mu, sigma = self.mu, self.sigma

        # Out-of-place epsilon guards against division by zero without
        # mutating the stored sigma (and works for integer arrays).
        return (X - mu) / (sigma + 1e-32)

    def backward(self, X):
        """Undo ``forward``: return ``X * sigma + mu``."""
        return (X * self.sigma) + self.mu
