import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init
import numpy as np

from timm.models.layers import trunc_normal_
from torch.autograd import Variable


class SwinTransFER(torch.nn.Module):
    """Facial-expression classifier on top of a Swin Transformer encoder.

    In training mode ``self.encoder.forward_features`` must return a pair
    ``(x, x2)`` of token tensors shaped ``(B, L, C)`` and ``(B, L2, C2)``
    whose token counts are perfect squares (the original configuration used
    ``7 * 7`` tokens of 768 channels and ``14 * 14`` tokens of 384 channels).
    The model then returns classification logits, two class-activation maps
    (one per token grid) and the logits of a batch-attention branch. In eval
    mode only the classification logits are returned.

    Args:
        swin: encoder exposing ``forward_features``.
        swin_num_features: channel count ``C`` of the final-stage tokens.
        num_classes: number of output classes.
        stage3_num_features: channel count ``C2`` of the intermediate tokens
            (previously hard-coded to 384).
    """

    def __init__(
        self, swin, swin_num_features=768, num_classes=7, stage3_num_features=384
    ) -> None:
        super().__init__()
        self.encoder = swin
        self.num_classes = num_classes

        self.norm = nn.LayerNorm(swin_num_features)
        self.norm2 = nn.LayerNorm(stage3_num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.cls_head = nn.Linear(swin_num_features, num_classes)
        self.att_head = nn.Linear(swin_num_features, num_classes)
        self.projection_head = nn.Sequential(
            nn.Linear(swin_num_features, swin_num_features),
            nn.ReLU(),
            nn.Linear(swin_num_features, swin_num_features),
        )
        # Lifts the intermediate tokens to C channels so the classifier
        # weights can be applied to them as well.
        self.conv1x1 = nn.Conv2d(
            stage3_num_features, swin_num_features, kernel_size=1, stride=1, padding=0
        )

    def get_cos_sim_mat(self, f: torch.Tensor, scale=1.0) -> torch.Tensor:
        """Pairwise cosine similarity of the projected features.

        Args:
            f (torch.Tensor): ``(B, C)`` feature batch.
            scale (float): the similarity matrix is divided by this value.

        Returns:
            torch.Tensor: ``(B, B)`` scaled cosine similarity matrix.
        """
        z = self.projection_head(f)
        # F.normalize clamps the norm with an eps, so an all-zero projection
        # yields 0 similarity instead of a NaN from 0 / 0.
        z = F.normalize(z, dim=1)
        return torch.matmul(z, z.T) / scale

    def self_masking_matrix(
        self, matrix: torch.Tensor, replace: float = -10e6
    ) -> torch.Tensor:
        """Return a copy of ``matrix`` with its main diagonal set to ``replace``.

        The caller's tensor is not modified.

        Args:
            matrix (torch.Tensor): 2-D similarity matrix.
            replace (float): value written on the diagonal (default ``-1e7``),
                large and negative so a following softmax ignores self-pairs.

        Returns:
            torch.Tensor: self-masked similarity matrix.
        """
        diagonal = torch.eye(
            matrix.shape[0], matrix.shape[1], dtype=torch.bool, device=matrix.device
        )
        # Build the fill value by tensor multiplication rather than passing a
        # Python scalar to masked_fill: under float16 an out-of-range value
        # then saturates to -inf instead of raising an overflow error.
        fill = replace * torch.ones((), dtype=matrix.dtype, device=matrix.device)
        return torch.where(diagonal, fill, matrix)

    def get_att_mat(
        self, matrix: torch.Tensor, self_masking: bool
    ) -> torch.Tensor:
        """Turn a similarity matrix into row-wise attention weights.

        Args:
            matrix (torch.Tensor): scaled cosine similarity matrix.
            self_masking (bool): if True, each row ignores its own sample.

        Returns:
            torch.Tensor: matrix with a softmax applied along each row.
        """
        if self_masking:
            matrix = self.self_masking_matrix(matrix)
        return F.softmax(matrix, dim=1)

    def enhance_sim_mat(
        self,
        cos_sim_mat: torch.Tensor,
        labels: torch.Tensor,
        same_class_factor: float,
        diff_class_factor: float = 1.0,
    ) -> torch.Tensor:
        """Rescale similarities according to whether two samples share a label.

        Args:
            cos_sim_mat (torch.Tensor): ``(B, B)`` cosine similarity matrix.
            labels (torch.Tensor): ``B`` class labels, one per sample.
            same_class_factor (float): multiplier for same-label pairs.
            diff_class_factor (float): multiplier for different-label pairs
                (default 1.0 leaves them unchanged).

        Returns:
            torch.Tensor: rescaled similarity matrix (a new tensor).
        """
        labels = labels.reshape(-1)
        # mask[i, j] is True when samples i and j have the same label.
        same_class = labels.unsqueeze(0) == labels.unsqueeze(1)
        return torch.where(
            same_class,
            cos_sim_mat * same_class_factor,
            cos_sim_mat * diff_class_factor,
        )

    def _pool(self, tokens: torch.Tensor) -> torch.Tensor:
        """Average ``(B, L, C)`` tokens over ``L`` into a ``(B, C)`` feature."""
        return torch.flatten(self.avgpool(tokens.transpose(1, 2)), 1)

    @staticmethod
    def _tokens_to_map(tokens: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, L, C)`` tokens into a ``(B, C, H, W)`` map, ``H = W = sqrt(L)``."""
        batch, length, channels = tokens.shape
        side = int(round(length ** 0.5))
        if side * side != length:
            raise ValueError(f"expected a square number of tokens, got {length}")
        return tokens.transpose(1, 2).reshape(batch, channels, side, side)

    def _class_activation_maps(self, feat_map: torch.Tensor) -> torch.Tensor:
        """Weight ``(B, C, H, W)`` features by each classifier row -> ``(B, num_classes, H, W)``.

        The classifier weights are detached, so the maps send no gradient into
        ``cls_head``.
        """
        weights = self.cls_head.weight.detach()  # (num_classes, C)
        return torch.einsum("bchw,kc->bkhw", feat_map, weights)

    def forward(self, x, labels=None):
        """Run the model.

        Args:
            x: input batch, passed to ``self.encoder.forward_features``.
            labels: optional per-sample class labels. In training mode they
                boost same-class similarities (factor 1.5) in the attention
                branch; when ``None`` the boost is skipped. Ignored in eval mode.

        Returns:
            Training mode: ``(origin_output, hm, hm2, attention_output)``,
            where ``hm`` and ``hm2`` are ``(B, num_classes, H, W)`` maps.
            Eval mode: ``(B, num_classes)`` logits.
        """
        if not self.training:
            feats = self.encoder.forward_features(x)
            # NOTE(review): in training forward_features returns a
            # (final, intermediate) pair; accept either form here.
            if isinstance(feats, (tuple, list)):
                feats = feats[0]
            return self.cls_head(self._pool(self.norm(feats)))

        x, x2 = self.encoder.forward_features(x)
        x = self.norm(x)  # (B, L, C)
        x2 = self.norm2(x2)  # (B, L2, C2)
        feature = self._pool(x)  # (B, C)
        origin_output = self.cls_head(feature)

        # Batch attention: each sample aggregates the other samples' features,
        # weighted by their (optionally label-boosted) cosine similarity.
        cos_sim_mat = self.get_cos_sim_mat(feature)
        if labels is not None:
            cos_sim_mat = self.enhance_sim_mat(
                cos_sim_mat, labels=labels, same_class_factor=1.5
            )
        att_mat = self.get_att_mat(cos_sim_mat, True)
        attention_feature = torch.matmul(att_mat, feature)
        attention_output = self.att_head(attention_feature)

        hm = self._class_activation_maps(self._tokens_to_map(x))
        hm2 = self._class_activation_maps(self.conv1x1(self._tokens_to_map(x2)))

        return origin_output, hm, hm2, attention_output

class ResnetFER(nn.Module):
    """Facial-expression classifier on top of a ResNet-style encoder.

    The encoder must expose ``fc.in_features`` and ``forward_features``. In
    training mode ``forward_features`` must return a pair ``(x4, x2)`` of
    ``(B, C, H, W)`` feature maps; ``x4`` has ``num_ftrs`` channels and ``x2``
    has ``mid_num_features`` channels (presumably an intermediate ResNet
    stage — TODO confirm against the encoder).

    Args:
        resnet: encoder module.
        num_classes: number of output classes.
        mid_num_features: channel count of ``x2`` (previously hard-coded 128).
    """

    def __init__(self, resnet, num_classes=7, mid_num_features=128):
        super().__init__()
        self.encoder = resnet
        self.num_classes = num_classes
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.num_ftrs = self.encoder.fc.in_features
        self.cls_head = nn.Linear(self.num_ftrs, num_classes)
        self.att_head = nn.Linear(self.num_ftrs, num_classes)
        self.projection_head = nn.Sequential(
            nn.Linear(self.num_ftrs, self.num_ftrs),
            nn.ReLU(),
            nn.Linear(self.num_ftrs, self.num_ftrs),
        )
        # Lifts x2 to num_ftrs channels so the classifier weights apply to it.
        self.conv1x1 = nn.Conv2d(
            mid_num_features, self.num_ftrs, kernel_size=1, stride=1, padding=0
        )

    def get_cos_sim_mat(self, f: torch.Tensor, scale=1.0) -> torch.Tensor:
        """Pairwise cosine similarity of the projected features.

        Args:
            f (torch.Tensor): ``(B, num_ftrs)`` feature batch.
            scale (float): the similarity matrix is divided by this value.

        Returns:
            torch.Tensor: ``(B, B)`` scaled cosine similarity matrix.
        """
        z = self.projection_head(f)
        # F.normalize clamps the norm with an eps, so an all-zero projection
        # yields 0 similarity instead of a NaN from 0 / 0.
        z = F.normalize(z, dim=1)
        return torch.matmul(z, z.T) / scale

    def self_masking_matrix(
        self, matrix: torch.Tensor, replace: float = -10e6
    ) -> torch.Tensor:
        """Return a copy of ``matrix`` with its main diagonal set to ``replace``.

        The caller's tensor is not modified.

        Args:
            matrix (torch.Tensor): 2-D similarity matrix.
            replace (float): value written on the diagonal (default ``-1e7``),
                large and negative so a following softmax ignores self-pairs.

        Returns:
            torch.Tensor: self-masked similarity matrix.
        """
        diagonal = torch.eye(
            matrix.shape[0], matrix.shape[1], dtype=torch.bool, device=matrix.device
        )
        # Build the fill value by tensor multiplication rather than passing a
        # Python scalar to masked_fill: under float16 an out-of-range value
        # then saturates to -inf instead of raising an overflow error.
        fill = replace * torch.ones((), dtype=matrix.dtype, device=matrix.device)
        return torch.where(diagonal, fill, matrix)

    def get_att_mat(
        self, matrix: torch.Tensor, self_masking: bool
    ) -> torch.Tensor:
        """Turn a similarity matrix into row-wise attention weights.

        Args:
            matrix (torch.Tensor): scaled cosine similarity matrix.
            self_masking (bool): if True, each row ignores its own sample.

        Returns:
            torch.Tensor: matrix with a softmax applied along each row.
        """
        if self_masking:
            matrix = self.self_masking_matrix(matrix)
        return F.softmax(matrix, dim=1)

    def enhance_sim_mat(
        self,
        cos_sim_mat: torch.Tensor,
        labels: torch.Tensor,
        same_class_factor: float,
        diff_class_factor: float = 1.0,
    ) -> torch.Tensor:
        """Rescale similarities according to whether two samples share a label.

        Args:
            cos_sim_mat (torch.Tensor): ``(B, B)`` cosine similarity matrix.
            labels (torch.Tensor): ``B`` class labels, one per sample.
            same_class_factor (float): multiplier for same-label pairs.
            diff_class_factor (float): multiplier for different-label pairs
                (default 1.0 leaves them unchanged).

        Returns:
            torch.Tensor: rescaled similarity matrix (a new tensor).
        """
        labels = labels.reshape(-1)
        # mask[i, j] is True when samples i and j have the same label.
        same_class = labels.unsqueeze(0) == labels.unsqueeze(1)
        return torch.where(
            same_class,
            cos_sim_mat * same_class_factor,
            cos_sim_mat * diff_class_factor,
        )

    def _class_activation_maps(self, feat_map: torch.Tensor) -> torch.Tensor:
        """Weight ``(B, num_ftrs, H, W)`` features by each classifier row -> ``(B, num_classes, H, W)``.

        The classifier weights are detached, so the maps send no gradient into
        ``cls_head``.
        """
        weights = self.cls_head.weight.detach()  # (num_classes, num_ftrs)
        return torch.einsum("bchw,kc->bkhw", feat_map, weights)

    def forward(self, x, labels=None):
        """Run the model.

        Args:
            x: input batch, passed to ``self.encoder.forward_features``.
            labels: optional per-sample class labels (training mode only).
                When given, a batch-attention branch with same-class
                similarities boosted by 1.5 is also computed.

        Returns:
            Training mode: ``(origin_output, hm4, hm2)``, plus
            ``attention_output`` as a fourth element when ``labels`` is given.
            ``hm4`` and ``hm2`` are ``(B, num_classes, H, W)`` maps.
            Eval mode: ``(B, num_classes)`` logits.
        """
        if not self.training:
            feats = self.encoder.forward_features(x)
            # NOTE(review): in training forward_features returns a
            # (final, intermediate) pair; accept either form here.
            if isinstance(feats, (tuple, list)):
                feats = feats[0]
            return self.cls_head(torch.flatten(self.avgpool(feats), 1))

        x4, x2 = self.encoder.forward_features(x)
        feature = torch.flatten(self.avgpool(x4), 1)  # (B, num_ftrs)
        origin_output = self.cls_head(feature)

        hm4 = self._class_activation_maps(x4)
        hm2 = self._class_activation_maps(self.conv1x1(x2))

        if labels is None:
            return origin_output, hm4, hm2

        # Batch attention: each sample aggregates the other samples' features,
        # weighted by their label-boosted cosine similarity.
        cos_sim_mat = self.get_cos_sim_mat(feature)
        cos_sim_mat = self.enhance_sim_mat(
            cos_sim_mat, labels=labels, same_class_factor=1.5
        )
        att_mat = self.get_att_mat(cos_sim_mat, True)
        attention_output = self.att_head(torch.matmul(att_mat, feature))
        return origin_output, hm4, hm2, attention_output
