import math
import random

from surrogate.partitioning.partitioning import Partitioning


class CrossvalidationPartitioning(Partitioning):
    """Split the row indices of ``X`` into k cross-validation folds.

    Each fold serves once as the test set while the remaining folds,
    concatenated in order, form the corresponding training set.
    """

    def __init__(self,
                 X,
                 k_folds,
                 randomize=True,
                 n_sets=None):
        """Store the fold configuration.

        Args:
            X: Data to partition. ``_do`` unpacks ``X.shape`` into two
                values, so a 2-D array-like with rows as samples is expected.
            k_folds: Number of folds to create.
            randomize: If true, the row indices are shuffled with the
                module-level ``random`` generator before being split.
            n_sets: Forwarded unchanged to ``Partitioning.__init__``.
                NOTE(review): its meaning is defined by the base class,
                which is not visible here.
        """
        super().__init__(X, n_sets)
        self.k_folds = k_folds
        self.randomize = randomize

    def _do(self):
        """Build the ``(train, test)`` index lists for every fold.

        The ``n`` row indices are cut into ``k_folds`` contiguous chunks whose
        sizes differ by at most one: the first ``n % k_folds`` chunks receive
        one extra index. This avoids the empty trailing folds that a fixed
        chunk size of ``ceil(n / k_folds)`` produces when ``n`` is not a
        multiple of ``k_folds`` (e.g. n=5, k=4 used to give sizes 2, 2, 1, 0).
        If ``k_folds`` exceeds ``n``, the surplus folds are necessarily empty.

        Returns:
            A list with ``k_folds`` tuples ``(train, test)``; both elements
            are lists of integer row indices into ``X``.

        Raises:
            ZeroDivisionError: If ``k_folds`` is zero.
        """
        n, _ = self.X.shape

        # Computed before shuffling so an invalid k_folds fails without
        # consuming random state.
        base, extra = divmod(n, self.k_folds)

        indices = list(range(n))
        if self.randomize:
            random.shuffle(indices)

        ret = []
        start = 0
        for k in range(self.k_folds):
            # The first `extra` folds absorb the remainder, one index each.
            stop = start + base + (1 if k < extra else 0)

            test = indices[start:stop]
            # Folds are contiguous, so everything outside [start, stop) is
            # exactly the in-order concatenation of the other folds.
            train = indices[:start] + indices[stop:]

            ret.append((train, test))
            start = stop

        return ret
