try:
    import torch
    import torch.nn as nn
except:
    raise Exception("pytorch not found: 'pip install torch torchvision'")

import copy

import numpy as np

from surrogate.model import Model
from surrogate.partitioning.random import RandomPartitioning
from surrogate.util import misc


class DNN(Model):
    """Surrogate model backed by a fully connected network (``MLP``).

    Training is delegated to ``MLPLearning``; one entry per epoch is appended
    to ``self.history`` as ``dict(epoch=..., train_loss=..., vld_loss=...)``.

    Parameters
    ----------
    epochs : int
        Number of training epochs (cast with ``int``). Also used as the
        period of the cosine learning-rate schedule.
    n_layers : int
        Number of hidden ``Linear``/``ReLU`` pairs passed to ``MLP``.
    n_hidden : int
        Width of the hidden layers passed to ``MLP``.
    lr : float
        Learning rate of the Adam optimizer.
    drop_rate : float
        Dropout probability passed to ``MLP``.
    trn_split : float
        Fraction of the samples used for training; the remainder is used
        for validation.
    best_vld : bool
        If true, keep the network from the epoch with the lowest validation
        loss; otherwise keep the network as it is after the last epoch.
    pretrained : object or None
        If not ``None``, it is handed to ``torch.load`` and the result is
        loaded as the network's state dict; no training takes place.
    device : str
        Any value starting with ``"cuda"`` selects ``"cuda:0"`` when CUDA is
        available; everything else (or no CUDA) selects ``"cpu"``.
    **kwargs
        Forwarded unchanged to ``Model.__init__``.
    """

    def __init__(self,
                 epochs=1000,
                 n_layers=2,
                 n_hidden=300,
                 lr=8e-4,
                 drop_rate=0.2,
                 trn_split=0.8,
                 best_vld=True,
                 pretrained=None,
                 device="cpu",
                 **kwargs
                 ):

        super().__init__(**kwargs)

        # NOTE: a requested device index (e.g. "cuda:1") is not honoured,
        # the first CUDA device is always used.
        if device.startswith("cuda") and torch.cuda.is_available():
            self.device = "cuda:0"
        else:
            self.device = "cpu"

        self.epochs = int(epochs)
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.lr = lr
        self.drop_rate = drop_rate
        self.trn_split = trn_split
        self.pretrained = pretrained
        self.best_vld = best_vld

        # MLPLearning instance of the most recent training run (None until then)
        self.learning = None
        # per-epoch loss records; NOTE(review): never cleared, so repeated
        # fits keep appending to the same list - confirm this is intended
        self.history = []

    def _fit(self, X, y, **kwargs):
        """Build the network and either load pretrained weights or train it.

        ``X`` must be a 2-D array; its second dimension defines the number of
        input features. The resulting network is stored on the CPU in
        ``self.net``.
        """
        _, n_features = X.shape

        n_train = int(len(X) * self.trn_split)
        n_vld = len(X) - n_train
        train, vld = RandomPartitioning(X, n_train, n_vld).do()

        net = MLP(n_feature=n_features, n_layers=self.n_layers, n_hidden=self.n_hidden, drop=self.drop_rate)

        if self.pretrained is not None:
            # SECURITY NOTE: torch.load relies on pickle by default; only
            # point `pretrained` at checkpoints from a trusted source.
            init = torch.load(self.pretrained, map_location='cpu')
            net.load_state_dict(init)
            self.net = net
            return

        net.apply(MLP.init_weights)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.epochs, eta_min=0)

        self.learning = MLPLearning(net,
                                    self.device,
                                    self.epochs,
                                    optimizer,
                                    scheduler,
                                    nn.SmoothL1Loss(),
                                    X,
                                    y,
                                    train,
                                    vld)

        while self.learning.has_next():
            self.learning.next()
            epoch, train_loss, vld_loss = self.learning.epoch, self.learning.train_loss, self.learning.vld_loss

            self.history.append(dict(epoch=epoch, train_loss=train_loss, vld_loss=vld_loss))

            # self.verbose is not set in this class; presumably provided by Model
            if self.verbose and epoch % 100 == 0:
                print("Epoch {:4d}: trn loss = {:.4E}, vld loss = {:.4E}".format(epoch, train_loss, vld_loss))

        best = self.learning.min_vld_model if self.best_vld else None
        if best is None:
            # Either the last-epoch network was requested, or no epoch ever
            # recorded a validation loss below the initial infinity (no epoch
            # ran, or the loss was never a comparable number). Falling back
            # avoids leaving self.net set to None.
            best = self.learning.net
        self.net = best.to('cpu')

    def _predict(self, X, out, **kwargs):
        """Evaluate the fitted network on ``X`` and store the result in ``out["y"]``."""
        net = self.net
        net.eval()

        # _fit leaves the network on the CPU regardless of self.device, so the
        # input has to follow the network's parameters; sending it to
        # self.device would fail with a device mismatch when CUDA is selected.
        device = next(net.parameters()).device
        data = torch.from_numpy(X).float().to(device)

        with torch.no_grad():
            pred = net(data)

        out["y"] = pred.detach().cpu().numpy()


class MLPLearning:
    """Steps a network through full-batch training, one epoch per ``next()`` call.

    ``X`` and ``y`` are NumPy arrays that are converted to float tensors on
    ``device``; ``train`` and ``vld`` are index sets used to select the
    training and validation rows from them.

    After every epoch ``train_loss`` and ``vld_loss`` hold the latest loss
    values. Whenever the validation loss drops below the best value seen so
    far, a deep copy of the network is stored in ``min_vld_model`` together
    with ``min_vld_loss`` and ``min_vld_epoch``.
    """

    def __init__(self,
                 net,
                 device,
                 n_max_epochs,
                 optimizer,
                 scheduler,
                 criterion,
                 X,
                 y,
                 train,
                 vld) -> None:
        super().__init__()

        # network and data live on the requested device
        self.net = net.to(device)
        features, targets = torch.from_numpy(X), torch.from_numpy(y)
        self.X = features.float().to(device)
        self.y = targets.float().to(device)

        # configuration of the run
        self.device = device
        self.n_max_epochs = n_max_epochs
        self.train = train
        self.vld = vld
        self.optimizer, self.scheduler, self.criterion = optimizer, scheduler, criterion

        # progress of the run
        self.epoch = 0
        self.train_loss = None
        self.vld_loss = None

        # best validation result observed so far
        self.min_vld_loss = float("inf")
        self.min_vld_model = None
        self.min_vld_epoch = None

    def next(self):
        """Run one epoch: a training step, a validation pass, a scheduler step."""
        model, loss_fn = self.net, self.criterion
        trn_idx, vld_idx = self.train, self.vld

        # one optimizer step on the complete training subset
        model.train()
        self.optimizer.zero_grad()
        trn_loss = loss_fn(model(self.X[trn_idx]), self.y[trn_idx])
        trn_loss.backward()
        self.optimizer.step()
        self.train_loss = trn_loss.item()

        # validation loss in evaluation mode, without gradient tracking
        model.eval()
        with torch.no_grad():
            vld_loss = loss_fn(model(self.X[vld_idx]), self.y[vld_idx])
        self.vld_loss = vld_loss.item()

        self.scheduler.step()

        # snapshot the network when the validation loss improved
        improved = self.vld_loss < self.min_vld_loss
        if improved:
            self.min_vld_model = copy.deepcopy(model)
            self.min_vld_loss = self.vld_loss
            self.min_vld_epoch = self.epoch

        self.epoch += 1

    def has_next(self):
        """Return True while fewer than ``n_max_epochs`` epochs have been run."""
        return self.n_max_epochs > self.epoch


class MLP(nn.Module):
    """Fully connected regression network.

    Layout: ``stem`` (Linear + ReLU) -> ``hidden`` (``n_layers`` Linear + ReLU
    pairs of width ``n_hidden``) -> dropout -> ``regressor`` (Linear without
    an activation).

    Parameters
    ----------
    n_feature : int
        Size of the input vectors.
    n_layers : int
        Number of Linear/ReLU pairs in the hidden stack.
    n_hidden : int
        Width of the stem and of every hidden layer.
    drop : float
        Dropout probability applied right before the output layer.
    n_output : int
        Size of the output vectors.
    """

    def __init__(self, n_feature, n_layers, n_hidden, drop, n_output=1):
        super(MLP, self).__init__()

        # Sub-module names and registration order define the state-dict
        # layout and therefore must stay as they are.
        self.stem = nn.Sequential(nn.Linear(n_feature, n_hidden), nn.ReLU())

        layers = []
        for _ in range(n_layers):
            layers.append(nn.Linear(n_hidden, n_hidden))
            layers.append(nn.ReLU())
        self.hidden = nn.Sequential(*layers)

        self.regressor = nn.Linear(n_hidden, n_output)  # output layer
        self.drop = nn.Dropout(p=drop)

    def forward(self, x):
        """Map a batch of inputs to predictions with ``n_output`` columns."""
        x = self.stem(x)
        x = self.hidden(x)
        x = self.drop(x)
        x = self.regressor(x)  # linear output
        return x

    @staticmethod
    def init_weights(m):
        """Initialise a linear layer in place; intended for ``Module.apply``.

        Weights are drawn uniformly from ``[-1/sqrt(in_features),
        1/sqrt(in_features))`` and the bias, if the layer has one, is set to
        zero. Modules that are not linear layers are left untouched.
        """
        if not isinstance(m, nn.Linear):
            return

        bound = float(1.0 / np.sqrt(m.in_features))
        nn.init.uniform_(m.weight, -bound, bound)

        # layers created with bias=False have no bias tensor to reset
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def train_one_epoch(net, data, target, criterion, optimizer, device):
    """Perform a single optimisation step on one batch and return its loss.

    The network is switched to training mode, ``data`` and ``target`` are
    moved to ``device``, the loss given by ``criterion`` is back-propagated
    and ``optimizer`` takes one step. The loss is returned as a Python float.
    """
    net.train()
    optimizer.zero_grad()

    inputs = data.to(device)
    labels = target.to(device)

    batch_loss = criterion(net(inputs), labels)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()


def infer(net, data, target, criterion, device):
    """Return the loss of ``net`` on one batch without updating the network.

    The network is switched to evaluation mode, ``data`` and ``target`` are
    moved to ``device`` and the loss given by ``criterion`` is computed with
    gradient tracking disabled. The loss is returned as a Python float.
    """
    # nn.Module has no `process` method; evaluation mode is selected with eval()
    net.eval()

    with torch.no_grad():
        data, target = data.to(device), target.to(device)
        pred = net(data)
        loss = criterion(pred, target)

    return loss.item()


def validate(net, data, target, device):
    """Predict on one batch and score the predictions against the targets.

    The network is switched to evaluation mode and run on ``data`` with
    gradient tracking disabled. Predictions and targets are converted to
    NumPy arrays and passed to ``misc.get_correlation``.

    Returns
    -------
    tuple
        ``(rmse, rho, tau, pred, target)``: the three values returned by
        ``misc.get_correlation`` (presumably RMSE, Spearman's rho and
        Kendall's tau - confirm against that helper), followed by the
        predictions and targets as NumPy arrays.
    """
    # nn.Module has no `process` method; evaluation mode is selected with eval()
    net.eval()

    with torch.no_grad():
        data, target = data.to(device), target.to(device)
        pred = net(data)

    pred, target = pred.cpu().detach().numpy(), target.cpu().detach().numpy()
    rmse, rho, tau = misc.get_correlation(pred, target)

    return rmse, rho, tau, pred, target
