import argparse
import logging
import os

import math
import torch
import torch.nn as nn
from torch.nn import DataParallel
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import  models
from torch.utils.tensorboard import SummaryWriter
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

from backbones import get_model, resnet18, resnet50
from lr_scheduler import build_scheduler

from expression import *
from expression.models import SwinTransFER, ResnetFER

from utils.utils_callbacks import CallBackLogging
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging
from utils.setup_seed import setup_seed

# Require PyTorch >= 1.9.0 at import time.
# NOTE(review): this is a string comparison. It presumably relies on
# `torch.__version__` being a `TorchVersion` (PyTorch >= 1.10), which compares
# by version semantics, so that e.g. "1.10.0" is not rejected lexicographically;
# confirm for the torch builds you target. `assert` is also skipped under `python -O`.
assert torch.__version__ >= "1.9.0"

class RecorderMeter:
    """Empty placeholder class; it defines no attributes or methods.

    Nothing else in this module references it.
    """

class LSR2(nn.Module):
    """Cross-entropy with class-weighted label smoothing.

    For each sample the target distribution puts ``1 - e`` on the true class.
    The remaining mass ``e`` is spread over the other classes in proportion to
    ``balance_weight``, renormalized over those non-target classes. The loss is
    the batch mean of ``-sum(log_softmax(x) * smoothed_target)``.

    Args:
        e: smoothing factor (mass moved off the true class).
        balance_weight: optional per-class weights, one per logit column. It
            defaults to the original hard-coded 7-class weights, so existing
            ``LSR2(e)`` callers behave exactly as before.
    """

    # Original hard-coded weights (7 entries, one per class of the 7-way head).
    _DEFAULT_BALANCE_WEIGHT = (0.95124031, 4.36690391, 1.71143654, 0.25714585,
                               0.6191221, 1.74056738, 0.48617274)

    def __init__(self, e, balance_weight=None):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.e = e
        if balance_weight is None:
            balance_weight = self._DEFAULT_BALANCE_WEIGHT
        # Built once instead of on every forward call. It is non-persistent so
        # the module's state_dict is unchanged, and it is float (default dtype)
        # so the in-place normalization below is valid even for integer input.
        self.register_buffer(
            "balance_weight",
            torch.as_tensor(balance_weight, dtype=torch.get_default_dtype()).reshape(-1),
            persistent=False,
        )

    def _one_hot(self, labels, classes, value=1):
        """Return a (batch, classes) float tensor with ``value`` at each label index."""
        labels = labels.view(labels.size(0), -1)
        # Allocate directly on the labels' device rather than on CPU + copy.
        one_hot = torch.zeros(labels.size(0), classes, device=labels.device)
        value_added = torch.full((labels.size(0), 1), value,
                                 dtype=one_hot.dtype, device=labels.device)
        one_hot.scatter_add_(1, labels, value_added)
        return one_hot

    def _smooth_label(self, target, length, smooth_factor):
        """Build the smoothed target distribution of shape (batch, length).

        Raises:
            ValueError: if ``balance_weight`` does not have ``length`` entries.
        """
        one_hot = self._one_hot(target, length, value=1 - smooth_factor)
        # Non-target positions receive the smoothing mass.
        mask = one_hot == 0
        weight = self.balance_weight.to(one_hot.device)
        if weight.numel() != length:
            raise ValueError(
                f"LSR2 balance_weight has {weight.numel()} entries but the "
                f"input has {length} classes"
            )
        batch = one_hot.size(0)
        # Every row has the same number of masked entries, so a (batch, -1) view is valid.
        resize_weight = weight.expand(batch, -1)[mask].view(batch, -1)
        resize_weight = resize_weight / resize_weight.sum(dim=1, keepdim=True)
        one_hot[mask] += (resize_weight * smooth_factor).view(-1)
        return one_hot.to(target.device)

    def forward(self, x, target):
        """Return the mean smoothed cross-entropy of logits ``x`` (batch, classes) vs ``target``."""
        smoothed_target = self._smooth_label(target, x.size(1), self.e)
        x = self.log_softmax(x)
        loss = torch.sum(-x * smoothed_target, dim=1)
        return torch.mean(loss)
    

def _save_checkpoint(path, epoch, global_step, local_step, model, opt, lr_scheduler):
    """Write the full training state to ``path`` with ``torch.save``.

    The dictionary keys are the ones the resume branch of ``main`` reads back.
    """
    torch.save(
        {
            "epoch": epoch,
            "global_step": global_step,
            "local_step": local_step,
            "state_dict_model": model.state_dict(),
            "state_optimizer": opt.state_dict(),
            "state_lr_scheduler": lr_scheduler.state_dict(),
        },
        path,
    )


def main(args):
    """Train ``SwinTransFER`` (7 classes) on the ``"Expression"`` dataset.

    Reads the config named by ``args.config`` and builds the train/val loaders,
    model, optimizer and LR scheduler. It optionally initializes the encoder from
    ``cfg.init_model/swin.pt`` and/or resumes from
    ``cfg.output/checkpoint_step_{cfg.resume_step}_gpu.pt``, then trains until
    ``cfg.total_step`` optimizer steps have run.

    Validation runs every ``cfg.verbose`` steps and once at the end. Checkpoints
    are written every ``cfg.save_verbose`` steps (when ``cfg.save_all_states``)
    and whenever validation mean / top-1 accuracy improves.

    Args:
        args: parsed CLI namespace; only ``args.config`` is read.

    Raises:
        ValueError: if ``cfg.optimizer`` is neither ``"sgd"`` nor ``"adamw"``.
    """
    cfg = get_config(args.config)
    setup_seed(seed=cfg.seed, cuda_deterministic=True)

    if not torch.cuda.is_available():
        print("There is no GPU available")
        return
    # The GPU index was hard-coded to 3; keep that default but allow an
    # optional ``gpu_id`` config entry to override it.
    gpu_id = getattr(cfg, "gpu_id", 3)
    device = torch.device("cuda", gpu_id)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(cfg.output)

    summary_writer = SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))

    expression_train_loader = get_analysis_train_dataloader("Expression", cfg)

    swin = get_model(cfg.network)
    model = SwinTransFER(swin, num_classes=7)

    torch.cuda.set_device(gpu_id)
    model.to(device)
    model.train()

    cfg.epoch_step = len(expression_train_loader)
    cfg.num_epoch = math.ceil(cfg.total_step / cfg.epoch_step)
    # Linear LR scaling relative to a reference batch size of 512.
    cfg.lr = cfg.lr * cfg.batch_size / 512.0
    cfg.warmup_lr = cfg.warmup_lr * cfg.batch_size / 512.0
    cfg.min_lr = cfg.min_lr * cfg.batch_size / 512.0

    if cfg.optimizer == "sgd":
        opt = torch.optim.SGD(
            params=[{"params": model.parameters(), "lr": cfg.lr}],
            lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
    elif cfg.optimizer == "adamw":
        opt = torch.optim.AdamW(
            params=[{"params": model.parameters(), "lr": cfg.lr}],
            lr=cfg.lr, weight_decay=cfg.weight_decay)
    else:
        # A bare `raise` here has no active exception and fails with an
        # unrelated "No active exception to reraise" RuntimeError.
        raise ValueError(
            f"Unsupported optimizer {cfg.optimizer!r}; expected 'sgd' or 'adamw'"
        )

    lr_scheduler = build_scheduler(
        optimizer=opt,
        lr_name=cfg.lr_name,
        warmup_lr=cfg.warmup_lr,
        min_lr=cfg.min_lr,
        num_steps=cfg.total_step,
        warmup_steps=cfg.warmup_step
    )

    start_epoch = 0
    global_step = 0
    # Leading batches to skip in the first epoch after a resume; it is reset to
    # 0 after that epoch so later epochs see their full data.
    local_step = 0

    if cfg.init:
        # map_location="cpu": load_state_dict copies onto the model's device,
        # so the checkpoint loads no matter which GPU saved it.
        dict_checkpoint = torch.load(os.path.join(cfg.init_model, "swin.pt"),
                                     map_location="cpu")
        model.encoder.load_state_dict(dict_checkpoint["state_dict_backbone"], strict=True)
        del dict_checkpoint

    if cfg.resume:
        dict_checkpoint = torch.load(
            os.path.join(cfg.output, f"checkpoint_step_{cfg.resume_step}_gpu.pt"),
            map_location="cpu",
        )
        start_epoch = dict_checkpoint["epoch"]
        global_step = dict_checkpoint["global_step"]
        local_step = dict_checkpoint["local_step"]
        print("continue training from checkpoint ...")
        # The checkpoint records the step that had just completed; move past it.
        if local_step == cfg.epoch_step - 1:
            start_epoch = start_epoch + 1
            local_step = 0
        else:
            local_step += 1
        global_step += 1
        model.load_state_dict(dict_checkpoint["state_dict_model"])
        opt.load_state_dict(dict_checkpoint["state_optimizer"])
        lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
        del dict_checkpoint

    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    expression_val_dataloader = get_analysis_val_dataloader("Expression", config=cfg)
    expression_verification = ExpressionVerification(data_loader=expression_val_dataloader, summary_writer=summary_writer)

    callback_logging = CallBackLogging(
        frequent=cfg.frequent,
        total_step=cfg.total_step,
        batch_size=cfg.batch_size,
        start_step=global_step,
        writer=summary_writer
    )

    criterion = nn.CrossEntropyLoss()
    # Stateless per call; built once instead of on every training step.
    lsr_criterion = LSR2(0.3)
    loss_am = AverageMeter()
    meanAcc = 0.0
    acc1 = 0.0
    highest_mean_acc = 0.0
    highest_top1_acc = 0.0

    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)

    for epoch in range(start_epoch, cfg.num_epoch):
        for idx, data in enumerate(expression_train_loader):
            # Skip batches already consumed before the resumed checkpoint.
            if idx < local_step:
                continue
            global_img, label = data
            global_img = global_img.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            # NOTE(review): no autocast context is visible here, so with
            # cfg.fp16 only loss scaling happens; confirm whether the model
            # enables autocast internally.
            expression_output, hm4, hm2, attention_output = model(global_img, label)

            # Match hm2 to hm4's spatial size before the MSE term.
            hm2 = F.interpolate(hm2, size=hm4.shape[-2:], mode='bilinear', align_corners=False)
            mse_loss = F.mse_loss(hm2, hm4)

            loss = (lsr_criterion(expression_output, label)
                    + 0.05 * mse_loss
                    + 0.01 * criterion(attention_output, label))
            if cfg.fp16:
                amp.scale(loss).backward()
                # Unscale before clipping so the norm threshold applies to true gradients.
                amp.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                amp.step(opt)
                amp.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                opt.step()

            opt.zero_grad()
            lr_scheduler.step_update(global_step)

            with torch.no_grad():
                loss_am.update(loss.item(), 1)

                callback_logging(global_step, loss_am, epoch, cfg.fp16,
                                 lr_scheduler._get_lr(global_step)[-1], amp)
                if (global_step + 1) % cfg.verbose == 0:
                    meanAcc, acc1 = expression_verification(global_step, model, label)

            if cfg.save_all_states and (global_step + 1) % cfg.save_verbose == 0:
                _save_checkpoint(
                    os.path.join(cfg.output, f"checkpoint_step_{global_step}_gpu.pt"),
                    epoch, global_step, idx, model, opt, lr_scheduler)
            if meanAcc > highest_mean_acc:
                highest_mean_acc = meanAcc
                _save_checkpoint(
                    os.path.join(cfg.output, "checkpoint_highest_mean_acc_gpu.pt"),
                    epoch, global_step, idx, model, opt, lr_scheduler)
            if acc1 > highest_top1_acc:
                highest_top1_acc = acc1
                _save_checkpoint(
                    os.path.join(cfg.output, "checkpoint_highest_top1_acc_gpu.pt"),
                    epoch, global_step, idx, model, opt, lr_scheduler)

            if global_step >= cfg.total_step - 1:
                break
            else:
                global_step += 1

        # Only the first resumed epoch skips batches; previously every later
        # epoch also dropped its first `local_step` batches.
        local_step = 0

        if global_step >= cfg.total_step - 1:
            break
        if cfg.dali:
            expression_train_loader.reset()

    with torch.no_grad():
        _, _ = expression_verification(global_step, model, 0)

if __name__ == "__main__":
    # Script entry point: enable the cuDNN autotuner, parse the CLI and train.
    # NOTE(review): main() calls setup_seed(cuda_deterministic=True), which may
    # override this flag; confirm the intended cuDNN setting.
    torch.backends.cudnn.benchmark = True
    cli_parser = argparse.ArgumentParser(description="Training in Pytorch")
    cli_parser.add_argument("--config", type=str, help="py config file")
    cli_args = cli_parser.parse_args()
    main(cli_args)
