import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

import os, argparse, torch, json
from model import TextureCodec
from utils import get_mesh, save_mesh_as_obj, create_log_dir, adjust_lr, visualize_uvs
from losses import RenderLoss
import torch.optim as optim
from torchvision.utils import save_image

# Anomaly detection makes autograd report the forward op that produced a NaN/Inf
# gradient; PyTorch recommends it only for debugging because it slows execution.
torch.autograd.set_detect_anomaly(True)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Select the CUDA device only when one exists: torch.cuda.set_device rejects
    # CPU devices, so calling it unconditionally crashed CPU-only machines.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

ap = argparse.ArgumentParser()

# ---- Model parameters -------------------------------------------------------
ap.add_argument('expt_name', type=str, help="name of experiment")
ap.add_argument('-m', '--mesh_path', type=str, default="clothdolls/model1.obj",
                help="path of mesh")
ap.add_argument('-s', '--scale', type=int, default=4,
                help="downsampling factor, only factors of 2 supported")
# store_false: args.no_gcnn / args.no_warp default to True and become False when
# the flag is passed, so downstream they act as "use_*" switches.
ap.add_argument('-ng', '--no_gcnn', action='store_false',
                help="use to disable graph cnn module")
ap.add_argument('-nw', '--no_warp', action='store_false',
                help="use to disable uv warping  module")

# ---- Loss parameters --------------------------------------------------------
# Camera positions used by the loss are drawn at random from these [lo, hi] ranges.
ap.add_argument('-ca', '--camera_azim_range', nargs=2, type=float, default=[0, 0],
                help="camera azimuthal angle range")
ap.add_argument('-ce', '--camera_elev_range', nargs=2, type=float, default=[0, 0],
                help="camera elevation angle range")
ap.add_argument('-cd', '--camera_dist_range', nargs=2, type=float, default=[1, 1],
                help="camera distance range")
ap.add_argument('-lw', '--loss_weights', nargs=4, type=float, default=[200, 10, 0, 1],
                help="weights of l2, ssim, vgg, tvl losses respectively")

# ---- Optimizer parameters ---------------------------------------------------
ap.add_argument('-st', '--steps', type=int, default=100000,
                help="total number of optimization steps")
ap.add_argument('-bs', '--batchsize', type=int, default=4,
                help="Number of views batched per step")
ap.add_argument('-lr', '--learning_rate', type=float, default=0.0005,
                help="peak learning rate")
ap.add_argument('-wu', '--warmup', type=int, default=5000,
                help="number of steps to reach peak lr")
ap.add_argument('-re', '--resume', action='store_true',
                help="whether to resume training")

# ---- Logging parameters -----------------------------------------------------
ap.add_argument('-vf', '--val_freq', type=int, default=100,
                help="steps per validation")
ap.add_argument('-pf', '--print_freq', type=int, default=20,
                help="steps per logging losses")

args = ap.parse_args()
# Re-key the positional weight list by loss name (same order as the --loss_weights help).
args.loss_weights = dict(zip(('l2', 'ssim', 'vgg', 'tvl'), args.loss_weights))

log_dir, write_log = create_log_dir(expt_name=args.expt_name, resume=args.resume)

if args.resume:
    # NOTE: when resuming, every hyper-parameter is taken from the checkpoint;
    # only expt_name / --resume from the current command line have any effect
    # (they were consumed by create_log_dir above).
    # map_location lets a checkpoint written on a GPU be resumed on another device.
    # weights_only=False is required because the checkpoint pickles the
    # argparse.Namespace (newer torch releases default to weights_only=True).
    # Full unpickling can run arbitrary code: only load checkpoints this script wrote.
    ckpt = torch.load(os.path.join(log_dir, "ckpt.pth"),
                      map_location=device, weights_only=False)
    args = ckpt['args']
    args.resume = True
    model_state = ckpt['model_state']
    opt_state = ckpt['opt_state']
    init_step = ckpt['step'] + 1  # the saved step had already completed
    best_psnr, best_ssim = ckpt['best_losses']  # validation metrics, not losses
    write_log("Resuming Training")
else:
    write_log("Training Parameters:")
    for k, v in vars(args).items():
        write_log(f'\t{k}: {v}')
    write_log("\n")
    init_step = 0
    best_psnr = 0.
    best_ssim = 0.

mesh = get_mesh(mesh_path=args.mesh_path, device=device)

# Renders both meshes from randomly sampled cameras and compares them.
render_loss = RenderLoss(
    elev_range=args.camera_elev_range,
    azim_range=args.camera_azim_range,
    dist_range=args.camera_dist_range,
    num_views_per_iteration=args.batchsize,
    device=device,
)

# args.no_gcnn / args.no_warp come from store_false flags: True means "enabled".
model = TextureCodec(
    mesh=mesh,
    scale=args.scale,
    use_gcnn=args.no_gcnn,
    use_uv_warp=args.no_warp,
).to(device)

# No lr here: the training loop calls adjust_lr(optimizer, ...) before every
# optimizer.step() (presumably overwriting the param-group lr).
optimizer = optim.AdamW(model.parameters(), betas=(0.5, 0.999))

if args.resume:
    # Restore weights and optimizer state captured in the checkpoint.
    model.load_state_dict(model_state)
    optimizer.load_state_dict(opt_state)

# Main optimization loop: only the compression loss is optimized; the
# downsampled mesh is tracked purely as a logged/validated baseline.
first_run = True
for step in range(init_step, args.steps):
    compressed_mesh, dsp_mesh = model()
    print_now = (step % args.print_freq == 0)

    cmp_loss = render_loss.compute_loss(mesh, compressed_mesh, 'c', print_now)
    cmp_total_loss = sum(v * args.loss_weights[k] for k, v in cmp_loss.items())

    # The downsample loss is only ever logged, never back-propagated, so skip
    # building an autograd graph for it (saves memory every step).
    with torch.no_grad():
        dsp_loss = render_loss.compute_loss(mesh, dsp_mesh, 'd', print_now)
        dsp_total_loss = sum(v * args.loss_weights[k] for k, v in dsp_loss.items())

    lr = adjust_lr(optimizer, lr=args.learning_rate, cur_step=step,
                   total_steps=args.steps, warmup=args.warmup)
    optimizer.zero_grad()
    cmp_total_loss.backward()
    optimizer.step()

    if print_now:
        # .item() forces a device sync, so scalars are only pulled when logging.
        cmp_loss = {k: round(v.item(), 4) for k, v in cmp_loss.items()}
        dsp_loss = {k: round(v.item(), 4) for k, v in dsp_loss.items()}
        write_log(f'Step: {step:05d}:')
        write_log(f'    CompressionLoss: {cmp_total_loss.item():.3f}, '
                  f'DownsampleLoss: {dsp_total_loss.item():.3f}')
        write_log('    Compression Losses:', cmp_loss)
        write_log('    Downsampled Losses:', dsp_loss)
        visualize_uvs(torch.tensor([512, 512], device=compressed_mesh.device),
                      texture=compressed_mesh.textures)

    if first_run:
        # Validate the downsampled baseline once per run; its metrics are reused
        # in every later validation log line.
        dsp_val_psnr, dsp_val_ssim, orig_img, dsp_img = render_loss.validate(mesh, dsp_mesh)
        save_mesh_as_obj(os.path.join(log_dir, "downsampled_mesh.obj"), dsp_mesh)
        save_image(dsp_img, os.path.join(log_dir, "img_downsampled.png"))
        save_image(orig_img, os.path.join(log_dir, "img_original.png"))
        first_run = False

    # Also validate/checkpoint on the final step so progress made since the last
    # periodic validation is not lost.
    is_last_step = (step == args.steps - 1)
    if step > 0 and (step % args.val_freq == 0 or is_last_step):
        cmp_val_psnr, cmp_val_ssim, _, cmp_img = render_loss.validate(mesh, compressed_mesh)
        write_log(f'Validation Compression PSNR: {cmp_val_psnr} ; Downsampling PSNR: {dsp_val_psnr}')
        write_log(f'Validation Compression SSIM: {cmp_val_ssim} ; Downsampling SSIM: {dsp_val_ssim}')
        if cmp_val_psnr > best_psnr:
            best_psnr = cmp_val_psnr
            write_log('New Best PSNR; Saving Mesh')
            save_mesh_as_obj(os.path.join(log_dir, "compressed_mesh_psnr.obj"), compressed_mesh)
            # Image name now matches the metric (was swapped with the SSIM one).
            save_image(cmp_img, os.path.join(log_dir, "img_compressed_psnr.png"))
        if cmp_val_ssim > best_ssim:
            best_ssim = cmp_val_ssim
            write_log('New Best SSIM; Saving Mesh')
            save_mesh_as_obj(os.path.join(log_dir, "compressed_mesh_ssim.obj"), compressed_mesh)
            save_image(cmp_img, os.path.join(log_dir, "img_compressed_ssim.png"))
        save_mesh_as_obj(os.path.join(log_dir, "compressed_mesh_latest.obj"), compressed_mesh)
        save_image(cmp_img, os.path.join(log_dir, "img_compressed_latest.png"))
        model_pth = os.path.join(log_dir, "ckpt.pth")
        torch.save({'step': step,
                    'model_state': model.state_dict(),
                    'opt_state': optimizer.state_dict(),
                    'args': args,
                    'best_losses': [best_psnr, best_ssim]}, model_pth)