import numpy as np
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import mean_squared_error, r2_score


def calc_metric(metric, y, y_hat):
    """Compute a named evaluation metric between targets and predictions.

    Supported ``metric`` values:

    - ``"mse"``: mean squared error (``sklearn.metrics.mean_squared_error``).
    - ``"mae"``: mean absolute error, computed with numpy.
    - ``"r2"``: coefficient of determination (``sklearn.metrics.r2_score``).
    - ``"corr-spear"``: Spearman rank correlation coefficient (p-value dropped).
    - ``"corr-kendall"``: Kendall's tau correlation coefficient (p-value dropped).

    Args:
        metric: Name of the metric to compute (one of the strings above).
        y: Ground-truth values (array-like).
        y_hat: Predicted values (array-like).

    Returns:
        The scalar metric value.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == "mse":
        return mean_squared_error(y, y_hat)
    if metric == "mae":
        # Convert first so plain lists work (like the sklearn branches) and
        # pandas objects are compared positionally, not by index alignment.
        diff = np.asarray(y) - np.asarray(y_hat)
        return np.mean(np.abs(diff))
    if metric == "r2":
        return r2_score(y, y_hat)
    if metric == "corr-spear":
        # Tuple unpacking works across SciPy versions; the `.correlation`
        # attribute is only a backward-compatibility alias in newer SciPy.
        statistic, _ = spearmanr(y, y_hat)
        return statistic
    if metric == "corr-kendall":
        statistic, _ = kendalltau(y, y_hat)
        return statistic
    # ValueError subclasses Exception, so existing `except Exception` handlers
    # still catch it.
    raise ValueError(f"Metric is not known: {metric!r}.")


def check_equal_shape(a, b):
    """Ensure ``a`` and ``b`` have the same shape.

    ``np.shape`` returns ``x.shape`` when the object has one, so ndarray
    behaviour is unchanged. Plain lists and other array-likes are also
    accepted.

    Args:
        a: First array-like.
        b: Second array-like.

    Raises:
        AssertionError: If the shapes differ. This is raised explicitly rather
            than via an ``assert`` statement, so the check is not stripped
            when Python runs with ``-O``.
    """
    shape_a = np.shape(a)
    shape_b = np.shape(b)
    if shape_a != shape_b:
        raise AssertionError(f"Shape mismatch: {shape_a} vs {shape_b}")


def mse(y, y_hat):
    """Return the mean squared error between ``y`` and ``y_hat``.

    Both inputs are converted with ``np.asarray`` before use. This means plain
    lists are accepted, and pandas objects are compared positionally instead
    of being index-aligned, where misalignment would otherwise produce NaNs
    that ``.mean()`` silently skips. For ndarray inputs the result is
    unchanged.

    Args:
        y: Ground-truth values (array-like).
        y_hat: Predicted values (array-like), same shape as ``y``.

    Returns:
        The mean of the element-wise squared differences.

    Raises:
        Whatever ``check_equal_shape`` raises when the shapes differ.
    """
    y = np.asarray(y)
    y_hat = np.asarray(y_hat)
    # Validate shapes before subtracting so mismatched inputs are rejected
    # rather than silently broadcast.
    check_equal_shape(y, y_hat)
    return np.power(y - y_hat, 2).mean()
