import numpy as np

from surrogate.partitioning.partitioning import Partitioning


class RandomPartitioning(Partitioning):
    """Split sample indices into random, disjoint train and test subsets.

    A random permutation of the row indices is drawn; the first ``n_train``
    entries form the training set and the following ``n_test`` entries form
    the test set.

    Parameters
    ----------
    X : array-like
        The data to partition. Only its length (number of rows) is used
        here; it is forwarded to the ``Partitioning`` base class.
    n_train : int
        Number of samples in the training set. Must be non-negative.
    n_test : int, optional
        Number of samples in the test set. Must be non-negative. If
        ``None`` (default), all samples not used for training are used,
        i.e. ``len(X) - n_train``.
    **kwargs
        Forwarded unchanged to the ``Partitioning`` constructor.

    Raises
    ------
    ValueError
        If ``n_train`` or ``n_test`` is negative, or if
        ``n_train + n_test`` exceeds ``len(X)``.
    """

    def __init__(self,
                 X,
                 n_train,
                 n_test=None,
                 **kwargs):

        super().__init__(X, **kwargs)

        n_total = len(X)

        # Negative sizes would be silently interpreted as "count from the
        # end" by the slicing in ``_do`` and yield wrongly sized splits.
        if n_train < 0:
            raise ValueError(
                f"n_train must be non-negative, but got {n_train}.")
        if n_test is not None and n_test < 0:
            raise ValueError(
                f"n_test must be non-negative, but got {n_test}.")

        self.n_train = n_train

        # By default every sample that is not used for training is a test
        # sample.
        self.n_test = n_test
        if self.n_test is None:
            self.n_test = n_total - self.n_train

        # This also catches n_train > len(X) when n_test was derived above:
        # the derived value is negative in that case.
        if self.n_test < 0 or self.n_train + self.n_test > n_total:
            raise ValueError(
                f"n_train ({self.n_train}) and n_test ({self.n_test}) together "
                f"must not exceed the overall number of samples ({n_total}).")

    def _do(self):
        """Draw one random train/test split.

        Returns
        -------
        train : numpy.ndarray
            ``n_train`` randomly chosen row indices of ``self.X``.
        test : numpy.ndarray
            ``n_test`` row indices of ``self.X``, disjoint from ``train``.
        """
        n_samples = self.X.shape[0]
        perm = np.random.permutation(n_samples)

        train = perm[:self.n_train]

        test = perm[self.n_train:]
        # n_test is resolved to an integer in __init__; the guard is kept so
        # that resetting the attribute to None afterwards still means
        # "use all remaining samples".
        if self.n_test is not None:
            test = test[:self.n_test]

        return train, test
