import torch
import torch.nn.functional as F
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesUV
from pytorch3d.io import load_objs_as_meshes, save_obj
from pytorch3d.ops import interpolate_face_attributes
from functools import partial

import numpy as np
from datetime import datetime as dt
import shutil as sh
import os, glob, math, cv2

def adjust_lr(optimizer, lr, cur_step, total_steps, warmup):
    """Set a linear-warmup + cosine-decay learning rate on every param group.

    While ``cur_step < warmup`` the multiplier ramps linearly from 0.1 to 1.0.
    Afterwards it follows half a cosine period that goes from 1.0 at
    ``cur_step == warmup`` to 0.0 at ``cur_step == total_steps``. The cosine
    is not clamped, so steps beyond ``total_steps`` make the rate rise again.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts whose
            ``'lr'`` entry is overwritten.
        lr: peak learning rate (the value the multiplier is applied to).
        cur_step: current step index.
        total_steps: step at which the cosine decay reaches zero.
        warmup: number of warmup steps.

    Returns:
        The learning rate that was written to the param groups.
    """
    if cur_step < warmup:
        lr_mult = 0.1 + 0.9 * cur_step / warmup
    elif total_steps == warmup:
        # No decay phase exists; stay at the peak rate instead of dividing
        # by a zero-length decay span.
        lr_mult = 1.0
    else:
        # Only evaluated after warmup, so a schedule with
        # total_steps == warmup no longer fails during the warmup steps.
        progress = (cur_step - warmup) / (total_steps - warmup)
        lr_mult = 0.5 * (math.cos(math.pi * progress) + 1.0)

    new_lr = lr * lr_mult
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr



def create_log_dir(expt_name:str, resume:bool = False):
    """Prepare the log directory of an experiment and build a logging helper.

    The directory lives under ``out_out/<expt_name>``. Experiments named
    ``debug`` or ``test`` always have a pre-existing directory wiped first.
    For a fresh (non-resumed) run whose directory already exists, a
    ``MMDD_HHMMSS`` timestamp is appended to the path; the directory and a
    ``bck`` sub-directory are then created and every ``*.py`` file of the
    current working directory is copied into ``bck``.

    Args:
        expt_name: experiment name, used as the directory name.
        resume: reuse the existing directory and skip the source backup.

    Returns:
        ``(log_path, writelog)`` where ``writelog(*args)`` prints its
        arguments (each followed by a space) and appends the same line to
        ``<log_path>/train.log``.
    """
    log_path = os.path.join("out_out", expt_name)
    if expt_name in ('debug', 'test') and os.path.exists(log_path):
        sh.rmtree(log_path)

    if resume:
        # The directory may be missing (first launch with resume=True, or a
        # debug/test run wiped just above); create it so that the first log
        # write cannot fail with FileNotFoundError.
        os.makedirs(log_path, exist_ok=True)
    else:
        if os.path.exists(log_path):
            d = dt.now()
            dstring = f'{d.month:02d}{d.day:02d}_{d.hour:02d}{d.minute:02d}{d.second:02d}'
            log_path += dstring
        os.makedirs(log_path)
        bck_path = os.path.join(log_path,'bck')
        os.makedirs(bck_path)
        # Snapshot the scripts of the working directory next to the logs.
        for f in glob.glob("*.py"):
            sh.copy2(f, bck_path)

    logfile = os.path.join(log_path, "train.log")

    def writelog(*args):
        # str() rather than an f-string: 0-d tensors format differently.
        out_str = "".join(str(arg) + " " for arg in args)
        print(out_str)
        # Context manager closes the handle even if the write raises.
        with open(logfile, 'a') as f:
            f.write(out_str + "\n")

    return log_path, writelog



def sample_texture(texture:torch.Tensor, mesh:Meshes):
    """Bake a UV texture map into per-vertex features.

    Every face corner samples the texture at its UV coordinate. All samples
    that land on the same mesh vertex are then averaged (a vertex can be
    referenced by several faces, possibly with different UVs).

    Args:
        texture: 4-D texture map, unpacked as [N,C,H,W]. The sampling grid
            built here has batch size 1, so N is expected to be 1.
        mesh: textured mesh; its padded UV tensors are expected to hold a
            single mesh (leading batch dimension of 1).

    Returns:
        [V,C] tensor of averaged texture samples, one row per packed vertex.
        Vertices that no face references get zeros.
    """
    _, C, _, _ = texture.size()
    V = mesh.verts_packed().size(0)
    faces = mesh.faces_packed() #[F,3]
    # squeeze(0) drops only the batch dimension, so a mesh with a single
    # face or a single UV entry keeps its remaining dimensions.
    face_uv_idxs = mesh.textures.faces_uvs_padded().squeeze(0) #[F,3] -|
    verts_uvs = mesh.textures.verts_uvs_padded().squeeze(0) #[UV,2]  <-|
    align_corners = mesh.textures.align_corners

    #Converting u,v values from .obj specification to format required for grid_sample
    verts_uvs = 2.*verts_uvs - 1. #[0,1]->[-1,1] (new tensor, the mesh is not modified)
    verts_uvs[:,1] *= -1. # .obj has -1,-1 as bottom left, grid sample expects -1,-1  at top left

    verts_uvs = verts_uvs[face_uv_idxs.flatten(),:] #[F*3,2]
    # grid_sample output is [1,C,1,F*3]; index the singleton dims explicitly
    # so that C == 1 or a single sample is not collapsed by accident.
    sampled = F.grid_sample(texture, verts_uvs[None,None,:,:], align_corners=align_corners)
    verts_features = sampled[0,:,0,:].transpose(1,0) #[F*3,C]

    sampled_texture = torch.zeros((V,C)).to(texture)
    sampled_count = torch.zeros((V,1)).to(texture)

    # index_add_ accumulates correctly when the same vertex index repeats.
    vert_idxs = faces.flatten() #[F*3]
    sampled_texture.index_add_(0, vert_idxs, verts_features)
    sampled_count.index_add_(0, vert_idxs, torch.ones_like(verts_features[:,:1]))
    # Clamp the count so unreferenced vertices yield 0 instead of 0/0 = NaN.
    return sampled_texture/sampled_count.clamp(min=1.)



def texture_to_fragments(size:torch.Tensor, texture:TexturesUV):
    """Rasterise the UV layout of a texture into per-texel fragments.

    For every texel of a ``size`` grid this finds a UV triangle covering the
    texel centre and the barycentric coordinates of the centre inside it.
    Faces are processed in index order, so where triangles overlap the
    highest face index wins.

    Args:
        size: two-element tensor holding the raster size as [w,h].
        texture: UV texture; the leading batch dimension of its padded UV
            tensors is squeezed away, i.e. a single mesh is expected.

    Returns:
        ``(pix_to_face, bary_coords)`` with shapes [1,h,w,1] and [1,h,w,1,3].
        Texels not covered by any face hold -1 in both tensors. Both are
        flipped along the height axis before being returned.
    """
    # size in [w,h]
    verts_uvs = texture.verts_uvs_padded().squeeze(0) #[V,2]   <-|
    face_uv_idxs = texture.faces_uvs_padded().squeeze(0) #[F,3] _|
    K = face_uv_idxs.size(0)
    # verts_uvs[:,1] = 1.-verts_uvs[:,1]
    verts_uvs = verts_uvs*(size-1.)+ 0.5 #+0.5 for align_corners=True
    face_uvs = verts_uvs[face_uv_idxs.flatten(),:].reshape(K,3,2).permute(0,2,1) #[K,2,3]
    # Integer bounding box of every face, in texel units.
    face_max = torch.max(face_uvs,dim=2).values.ceil().long()
    face_min = torch.min(face_uvs,dim=2).values.floor().long()

    w,h = size
    pix_to_face = torch.ones([h,w]).to(size).long()*-1
    bary_coords = torch.ones([h,w,3]).to(size)*-1.

    for i in range(K):
        face = face_uvs[i,:,:] #[2,3]
        xmin, ymin = face_min[i,:]
        xmax, ymax = face_max[i,:]
        idx_x,idx_y = torch.meshgrid(torch.arange(xmin,xmax, device=xmin.device), torch.arange(ymin,ymax,device=ymin.device), indexing='ij')
        search_space = torch.stack([idx_x,idx_y], dim=-1).view(-1,2) #[N,2]
        N = search_space.size(0)
        if N == 0:
            # Empty bounding box: there is no texel to test.
            continue

        v01 = face[:,1] - face[:,0] #[2]
        v12 = face[:,2] - face[:,1] #[2]
        # Twice the signed area of the UV triangle.
        area = torch.linalg.det(torch.stack([v01, v12], dim=-1))
        if area == 0:
            # Degenerate triangle: dividing by the area only produces
            # inf/NaN weights that can never pass the containment test.
            continue

        # Texel centres sit at integer index + 0.5.
        vp0 = (search_space + 0.5) - face[:,0] #[N,2]
        vp1 = (search_space + 0.5) - face[:,1] #[N,2]

        bc2 = torch.linalg.det(torch.stack([v01.expand(N,2), vp0], dim=-1))/area
        bc0 = torch.linalg.det(torch.stack([v12.expand(N,2), vp1], dim=-1))/area
        bc1 = torch.ones_like(bc0) - bc0 - bc2
        inside = (bc0>=0.) & (bc1>=0.) & (bc2>=0.)
        # squeeze(-1) keeps a 1-D index even when exactly one texel matches.
        contains = torch.nonzero(inside).squeeze(-1)
        valid_bcs =  torch.stack([bc0[contains], bc1[contains], bc2[contains]], dim=-1)
        tri_ind_x, tri_ind_y = idx_x.flatten()[contains], idx_y.flatten()[contains]
        pix_to_face[tri_ind_y, tri_ind_x] = i
        bary_coords[tri_ind_y, tri_ind_x, :] = valid_bcs

    pix_to_face = pix_to_face.unsqueeze(0).unsqueeze(-1)
    bary_coords = bary_coords.unsqueeze(0).unsqueeze(3)
    # Flip rows: the UV origin is at the bottom, image row 0 is at the top.
    pix_to_face = torch.flip(pix_to_face, dims=(1,))
    bary_coords = torch.flip(bary_coords, dims=(1,))
    return pix_to_face, bary_coords



def get_mesh(mesh_path:str, device:torch.device):
    """Load an .obj file as a mesh, then recentre and rescale it in place.

    The vertices are first offset by minus the mean (over dim 1) of the first
    mesh's bounding box, then divided by twice the largest value of the
    recentred bounding box.
    # NOTE(review): presumably the bounding box is laid out as (3, 2) =
    # per-axis (min, max), which would make this "centre at the origin and
    # scale by the largest extent" - confirm against PyTorch3D.

    Args:
        mesh_path: path of the .obj file to load.
        device: device the mesh is loaded onto.

    Returns:
        The normalised mesh.
    """
    loaded = load_objs_as_meshes([mesh_path], device=device)

    bbox = loaded.get_bounding_boxes()[0]
    centred = loaded.offset_verts_(-bbox.mean(1))

    extent = centred.get_bounding_boxes()[0].max() * 2.
    return centred.scale_verts_(1. / extent.item())



def save_mesh_as_obj(path:str, mesh:Meshes):
    """Write a textured mesh to ``path`` through ``save_obj``.

    Packed vertices and faces are written together with the first entry of
    the padded UV coordinates, UV face indices and texture maps.

    Args:
        path: destination of the .obj file.
        mesh: mesh whose ``textures`` expose padded UV tensors and maps.
    """
    uv_texture = mesh.textures
    save_obj(
        path,
        verts=mesh.verts_packed(),
        faces=mesh.faces_packed(),
        verts_uvs=uv_texture.verts_uvs_padded()[0],
        faces_uvs=uv_texture.faces_uvs_padded()[0],
        texture_map=uv_texture.maps_padded()[0],
    )



def warp_uvs(verts_uvs:torch.Tensor, flow:torch.Tensor, align_corners:bool=True):
    """Move UV coordinates through a dense flow field.

    The UVs are mapped from [0,1] to the [-1,1] grid_sample convention with
    the v axis flipped, the flow field is sampled at those positions, and the
    sampled values are mapped back to [0,1] (v flipped again).

    Args:
        verts_uvs: [1,UV,2] UV coordinates; not modified.
        flow: [1,H,W,2] field, channels last, values in grid_sample
            coordinates.
        align_corners: forwarded to ``F.grid_sample``.

    Returns:
        [1,UV,2] warped UV coordinates.
    """
    # [0,1] -> [-1,1]; this creates a new tensor, so the in-place flip below
    # never touches the caller's data.
    grid = verts_uvs * 2. - 1.
    grid[:, :, 1] *= -1.

    flow_nchw = flow.permute(0, 3, 1, 2)  # [1,2,H,W]
    sampled = F.grid_sample(
        flow_nchw, grid[:, None, :, :], align_corners=align_corners
    )  # [1,2,1,UV]

    warped = sampled.squeeze(2).permute(0, 2, 1)  # [1,UV,2]
    warped[:, :, 1] *= -1.
    return (warped + 1.) / 2.



def smoothen_flow(warped_uvs, face_uv_idxs, p2f, bcs):
    """Interpolate per-UV-vertex warped coordinates into a dense field.

    The warped UVs are mapped from [0,1] to [-1,1] with the last component
    negated, gathered per face corner, and interpolated over the fragments
    with ``interpolate_face_attributes``.

    Args:
        warped_uvs: warped UV coordinates with a leading batch dim of 1.
        face_uv_idxs: UV indices of every face corner (a leading batch dim
            of 1 is squeezed away).
        p2f: pix_to_face tensor passed to ``interpolate_face_attributes``.
        bcs: barycentric coordinates matching ``p2f``.

    Returns:
        The interpolated attributes with dimension 3 squeezed out.
    """
    num_faces = face_uv_idxs.squeeze(0).size(0)

    # Arithmetic yields a new tensor, so the in-place flip is local.
    flow_per_uv = warped_uvs.squeeze(0)*2. - 1.
    flow_per_uv[:,-1] *= -1.

    flow_per_corner = flow_per_uv[face_uv_idxs.flatten(), :].view(num_faces, 3, -1)
    dense_flow = interpolate_face_attributes(p2f, bcs, flow_per_corner)
    return dense_flow.squeeze(3)




def visualize_uvs(img_size:torch.Tensor, texture:TexturesUV):
    """Render the UV triangles of a texture as filled polygons in "uv.png".

    Every face gets a pseudo-random colour (the global NumPy RNG is seeded
    with 0, so colours are reproducible). The image is flipped vertically
    before it is written to the current working directory.

    Args:
        img_size: two values giving the image size as [w,h], the same
            convention as the other helpers of this file. Any iterable of two
            numbers works (tensor, ``torch.Size``, tuple).
        texture: UV texture; the leading batch dimension of its padded UV
            tensors is squeezed away, i.e. a single mesh is expected.
    """
    w, h = (int(s) for s in img_size)
    face_uv_idxs = texture.faces_uvs_padded().squeeze(0) #[F,3] -|
    verts_uvs = texture.verts_uvs_padded().squeeze(0) #[UV,2]  <-|
    # Build the scale on the UVs' device so CUDA textures work too.
    scale = torch.tensor([w, h], dtype=verts_uvs.dtype, device=verts_uvs.device)
    faces = verts_uvs[face_uv_idxs.flatten(),:].view(*face_uv_idxs.shape,2) * scale
    # OpenCV contours must be 32-bit integer (x,y) points.
    faces = faces.round().int().cpu().numpy()
    # Image arrays are (rows, cols) = (h, w); u scales with w, v with h.
    vis_uv = np.zeros((h, w, 3), dtype=np.uint8)
    np.random.seed(0)
    for face in faces:
        b,g,r = (np.random.rand(3)*255)
        cv2.drawContours(vis_uv, [face], 0, (b,g,r,), -1)
    # UV origin is bottom-left, image origin is top-left.
    vis_uv = np.flip(vis_uv, axis=0)
    cv2.imwrite("uv.png", vis_uv)



def get_identity_warp(img_size:torch.Tensor, device):
    """Build a sampling grid that maps every pixel onto itself.

    Args:
        img_size: two values giving the grid size as [w,h].
        device: device the grid is moved to.

    Returns:
        [1,h,w,2] tensor whose last dimension holds (x, y), each spanning
        [-1,1] linearly across the width and the height respectively.
    """
    width, height = img_size
    xs = torch.linspace(-1, 1, width).to(device)
    ys = torch.linspace(-1, 1, height).to(device)

    # Broadcast the two 1-D ramps to the full [h,w] grid.
    rows, cols = ys.numel(), xs.numel()
    grid_x = xs.unsqueeze(0).expand(rows, cols)
    grid_y = ys.unsqueeze(1).expand(rows, cols)
    return torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)



if __name__ == "__main__":
    # Manual smoke test (needs CUDA and the sample mesh on disk): bake the
    # texture onto the vertices, re-interpolate it over the UV raster and
    # save both images for a visual comparison.
    from pytorch3d.io import load_objs_as_meshes
    from pytorch3d.ops import interpolate_face_attributes
    from torchvision.utils import save_image

    mesh = load_objs_as_meshes(["carla001_Hiphop/00.obj"]).cuda()
    # Shift along y by minus the mean of row 1 of the bounding box.
    y_offset = -(mesh.get_bounding_boxes()[0][1]).mean()
    mesh = mesh.offset_verts_(torch.Tensor([0.,y_offset,0.]).cuda())
    faces = mesh.faces_packed()
    txtr = mesh.textures.maps_padded().permute(0,3,1,2) # channels first

    print("getting p2f")
    raster_size = torch.Tensor((512,512)).long().cuda()
    p2f, bcs = texture_to_fragments(raster_size, mesh.textures)

    print("sampling txtr")
    vert_colours = sample_texture(txtr, mesh)
    corner_colours = vert_colours[faces.flatten(), :].view(faces.size(0),3,-1)
    print("interping f attribs")
    timg = interpolate_face_attributes(p2f, bcs, corner_colours).squeeze(3).permute(0,3,1,2)
    for image, out_name in ((timg, "interp_txtr.png"), (txtr, "orig_txtr.png")):
        save_image(image, out_name)

    # Dump the face map (shifted by one so uncovered texels print as 0)
    # followed by the three barycentric channels.
    print(p2f.squeeze()+1.)
    flat_bcs = bcs.squeeze()
    for channel in range(3):
        print(flat_bcs[:,:,channel])

# if __name__ == "__main__":
    # mw = get_mesh("output/base360dist_warp_s8/compressed_mesh_ssim.obj", "cpu")
    # s = torch.zeros(512,512).size()
    # visualize_uvs(s, mw.textures)