import torch
import torch.nn as nn
import torch.nn.functional as F
import PIL
from PIL import Image

import numpy as np
import torchvision
import matplotlib.pyplot as plt

import losses.visualization as vi
import cv2



# Names exported by ``from <this module> import *``: only the SPD loss.
__all__ = ['SPD']


class SpatialNorm(nn.Module):
    """Compare student and teacher predictions after a softmax over the last dim.

    Args:
        divergence: ``'kl'`` selects ``nn.KLDivLoss`` (default reduction);
            any other value selects ``nn.MSELoss``.
    """

    def __init__(self, divergence='kl'):
        # nn.Module.__init__ must run before sub-modules can be assigned;
        # without it the assignments below raise AttributeError.
        super(SpatialNorm, self).__init__()
        self.divergence = divergence
        if divergence == 'kl':
            self.criterion = nn.KLDivLoss()
        else:
            self.criterion = nn.MSELoss()

        self.norm = nn.Softmax(dim=-1)

    def forward(self, pred_S, pred_T):
        """Return the divergence between the softmax-normalised predictions.

        Args:
            pred_S: Student prediction; normalised over its last dimension.
            pred_T: Teacher prediction; same shape as ``pred_S``.
        """
        norm_T = self.norm(pred_T)
        if self.divergence == 'kl':
            # KLDivLoss takes log-probabilities as input, probabilities as target.
            return self.criterion(F.log_softmax(pred_S, dim=-1), norm_T)
        return self.criterion(self.norm(pred_S), norm_T)


class ChannelNorm(nn.Module):
    """Softmax each channel of a 4-D feature map over its flattened spatial positions."""

    def __init__(self):
        super().__init__()

    def forward(self, featmap):
        """Map (N, C, H, W) -> (N, C, H*W); every (n, c) row sums to 1."""
        batch, channels, _, _ = featmap.shape  # unpacking rejects non-4-D input
        flat = featmap.reshape(batch, channels, -1)
        return torch.softmax(flat, dim=-1)


class SPD(nn.Module):
    """Selective distillation loss between student and teacher feature maps.

    ``forward`` reduces every teacher feature level in two steps and then
    compares the surviving teacher channels with the student features:

    1. :meth:`feature_discrimination` keeps the ``C // ratios_d[i]`` teacher
       channels with the lowest class-wise entropy score computed from the
       label map ``target``;
    2. :meth:`activation_ex` keeps, of those, the ``C' // ratios_a[i]``
       channels with the largest spatial L2 norm;
    3. :meth:`distance_loss` sums ``self.criterion`` over levels, each term
       divided by the student level's ``C * H * W``.

    Args:
        s_channels: Student channel count. Not used by this class.
        t_channels: Teacher channel count. Not used by this class.
        norm_type: ``'channel'``, ``'spatial'``, ``'channel_mean'`` or anything
            else (``None``); stored as ``self.normalize``, which no method of
            this class applies.
        divergence: ``'mse'`` or ``'kl'``; both use ``reduction='sum'``.
        temperature: Stored as ``self.temperature``; not used by this class.
        num_class: Number of real classes in ``target``; label ``-1`` is
            treated as background and excluded.

    Raises:
        ValueError: If ``divergence`` is neither ``'mse'`` nor ``'kl'``.
    """

    def __init__(self, s_channels, t_channels, norm_type='none', divergence='mse', temperature=1.0, num_class=19):
        super(SPD, self).__init__()

        # Optional normalisation, exposed for callers; not applied in this class.
        if norm_type == 'channel':
            self.normalize = ChannelNorm()
        elif norm_type == 'spatial':
            self.normalize = nn.Softmax(dim=1)
        elif norm_type == 'channel_mean':
            # NOTE(review): a lambda attribute makes the module un-picklable
            # (e.g. torch.save(model)); consider a named function.
            self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)
        else:
            self.normalize = None
        self.norm_type = norm_type

        # Distance between selected teacher channels and student channels.
        if divergence == 'mse':
            self.criterion = nn.MSELoss(reduction='sum')
        elif divergence == 'kl':
            # NOTE(review): KLDivLoss expects log-probabilities as input, but
            # distance_loss feeds raw feature values; confirm this is intended.
            self.criterion = nn.KLDivLoss(reduction='sum')
        else:
            # Fail here instead of with an AttributeError on the first forward.
            raise ValueError(f"unsupported divergence {divergence!r}; expected 'mse' or 'kl'")
        self.divergence = divergence
        self.temperature = temperature

        self.num_class = num_class

    def forward(self, preds_S, preds_T, target, ratios_d, ratios_a, matching="sort"):
        """Return the distillation loss summed over feature levels.

        Args:
            preds_S: Student feature maps, one (B, C, H, W) tensor per level.
            preds_T: Teacher feature maps, aligned level-by-level with ``preds_S``.
            target: Label map (B, H0, W0); ``-1`` marks background.
            ratios_d: Per-level divisor used by :meth:`feature_discrimination`.
            ratios_a: Per-level divisor used by :meth:`activation_ex`.
            matching: Channel pairing passed to :meth:`distance_loss`.
        """
        im_F = self.feature_discrimination(preds_T, target, ratios_d)
        im_F_D = self.activation_ex(im_F, ratios_a)
        return self.distance_loss(im_F_D, preds_S, matching)

    @staticmethod
    def _gather_channels(feat, index):
        """Return ``feat[b, index[b]]`` for every sample ``b``: (B, K, H, W)."""
        batch = torch.arange(feat.shape[0], device=feat.device).unsqueeze(1)
        return feat[batch, index]

    @staticmethod
    @torch.no_grad()
    def _channel_correlation(F_S, F_T):
        """Dot products of every student/teacher channel pair: (B, C_S, C_T).

        Only used to pick matches, so no autograd graph is recorded.
        """
        return torch.matmul(F_S.flatten(2), F_T.flatten(2).transpose(1, 2))

    def activation_ex(self, preds_T, ratios):
        """Keep the channels with the largest spatial L2 norm at every level.

        Args:
            preds_T: Feature maps, (B, C, H, W) each.
            ratios: Per-level divisor; level ``i`` keeps ``C // ratios[i]`` channels.

        Returns:
            List of (B, C // ratios[i], H, W) tensors, channels ordered by
            descending norm (per sample).
        """
        im_F = []
        for i, F_T in enumerate(preds_T):
            keep = F_T.shape[1] // ratios[i]
            # Norms are only used for ranking, so detach to skip the graph.
            order = F_T.detach().norm(dim=[2, 3]).sort(dim=1, descending=True)[1][:, :keep]
            im_F.append(self._gather_channels(F_T, order))
        return im_F

    @torch.no_grad()
    def _entropy_scores(self, F_T, target):
        """Per-channel entropy score of the class-wise mean activations: (B, C)."""
        n, c, h, w = F_T.shape
        # Resize labels to this level; -1 gets the extra id ``num_class`` so it
        # lands in the last one-hot slot, which is dropped below.
        labels = F.interpolate(target.unsqueeze(1).float(), (h, w), mode='nearest').long()
        labels[labels == -1] = self.num_class
        masks = F.one_hot(labels.squeeze(1), num_classes=self.num_class + 1)  # (B, H, W, K+1)
        masks = masks.permute(0, 3, 1, 2)[:, :-1].reshape(n, self.num_class, -1)  # (B, K, H*W)
        counts = masks.sum(2).clamp_min(1)  # pixels per class; absent classes divide by 1

        # a_{l,c}: mean activation of channel l over the pixels of class c -> (B, C, K)
        feats = F_T.reshape(n, c, -1)
        class_mean = torch.matmul(feats, masks.transpose(1, 2).to(feats.dtype)) / counts.unsqueeze(1)

        # p_{l,c}: softmax over classes, restricted to classes with a positive mean.
        positive = (class_mean > 0).to(class_mean.dtype)
        weights = torch.exp(class_mean - class_mean.max(2, keepdim=True)[0]) * positive
        denom = weights.sum(2, keepdim=True)
        # A channel with no positive class mean has denom == 0. Dividing by 1
        # keeps its probabilities at 0 instead of NaN; a NaN would also poison
        # entropy.max() below and corrupt every channel's score.
        prob = weights / denom.masked_fill(denom == 0, 1.0)
        prob = prob.masked_fill(prob == 0, 1.0)  # makes 0 * log(1/0) evaluate to 0

        # H_{l,c}: entropy terms; zero terms are replaced by the global maximum term.
        entropy = prob * torch.log(1.0 / prob)
        entropy = entropy.masked_fill(entropy == 0, entropy.max())
        return entropy.sum(2)

    def feature_discrimination(self, im_F, target, ratios):
        """Keep the teacher channels with the lowest entropy score at every level.

        Args:
            im_F: Teacher feature maps, (B, C, H, W) each.
            target: Label map (B, H0, W0), nearest-resized to each level.
                Labels other than -1 must lie in ``[0, num_class)``.
            ratios: Per-level divisor; level ``i`` keeps ``C // ratios[i]`` channels.

        Returns:
            List of (B, C // ratios[i], H, W) tensors, channels ordered by
            ascending entropy score (per sample).
        """
        im_F_D = []
        for i, F_T in enumerate(im_F):
            keep = F_T.shape[1] // ratios[i]
            order = self._entropy_scores(F_T, target).sort(dim=1)[1][:, :keep]
            im_F_D.append(self._gather_channels(F_T, order))
        return im_F_D

    def distance_loss(self, im_F_D, preds_S, sort="sort"):
        """Sum ``self.criterion`` between student and selected teacher features.

        Args:
            im_F_D: Selected teacher features per level, (B, C_T, H, W).
            preds_S: Student features per level, (B, C_S, H, W); the spatial
                size is assumed to match the teacher level.
            sort: Channel pairing strategy:
                ``'sort'`` - student channel k vs. selected teacher channel k
                (first ``C_T`` student channels only);
                ``'student_matching'`` - every student channel vs. its most
                correlated teacher channel (s:t = 1:m);
                ``'teacher_matching'`` - every teacher channel vs. its most
                correlated student channel (s:t = m:1).

        Returns:
            Sum over levels of ``criterion / (C_S * H * W)``; ``0`` when
            ``im_F_D`` is empty.

        Raises:
            ValueError: If ``sort`` is not one of the strategies above.
        """
        if sort not in ("sort", "student_matching", "teacher_matching"):
            raise ValueError(f"unknown matching strategy {sort!r}")

        loss = 0
        for i, F_T in enumerate(im_F_D):
            F_S = preds_S[i]
            n, c, h, w = F_S.shape
            if sort == "sort":
                F_loss = self.criterion(F_S[:, :F_T.shape[1]], F_T)
            elif sort == "student_matching":
                # (B, C_S): most correlated teacher channel for each student channel.
                match = self._channel_correlation(F_S, F_T).argmax(dim=2)
                F_loss = self.criterion(F_S, self._gather_channels(F_T, match))
            else:  # "teacher_matching"
                # (B, C_T): most correlated student channel for each teacher channel.
                match = self._channel_correlation(F_S, F_T).argmax(dim=1)
                F_loss = self.criterion(self._gather_channels(F_S, match), F_T)
            loss = loss + F_loss / (c * h * w)
        return loss

       

    
