import math
import torch
from pytorch3d.renderer import (
    AmbientLights,
    look_at_view_transform,
    MeshRenderer,
    MeshRasterizer,
    RasterizationSettings,
    SoftPhongShader, 
    FoVPerspectiveCameras,
)
from pytorch_msssim import ssim
from pytorch3d.loss import mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency
from torchvision.utils import save_image

from vgg_loss import VGGLoss, TVLoss

class RenderLoss:
    """Image-space losses between renderings of an original and a processed mesh.

    Both meshes are rendered with the same ambient-lit renderer from the same
    cameras, and the renderings are compared with a masked L2 term, SSIM, the
    project's ``VGGLoss`` and the project's ``TVLoss``.
    """

    def __init__(self, device, elev_range=(0., 0.), azim_range=(0., 0.), dist_range=(1., 1.), num_views_per_iteration=1, image_size=1000) -> None:
        """Build the renderer and the auxiliary loss modules.

        Args:
            device: Torch device used for the lights, shader, cameras and losses.
            elev_range: ``(low, high)`` bounds used when sampling camera elevation.
            azim_range: ``(low, high)`` bounds used when sampling camera azimuth.
            dist_range: ``(low, high)`` bounds used when sampling camera distance.
            num_views_per_iteration: Number of cameras sampled and rendered per
                ``compute_loss`` call.
            image_size: Value forwarded to ``RasterizationSettings(image_size=...)``.
        """
        lights = AmbientLights(ambient_color=((1., 1., 1.),), device=device)
        raster_settings = RasterizationSettings(image_size=image_size)
        self.renderer_textured = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings),
            shader=SoftPhongShader(device=device, lights=lights),
        )
        self.num_views_per_iteration = num_views_per_iteration
        self.device = device
        self.view_list = []
        # Copy the ranges so instances never alias a shared default (or a
        # caller-owned) sequence.
        self.elev_range = list(elev_range)
        self.azim_range = list(azim_range)
        self.dist_range = list(dist_range)
        # Sampled lazily by compute_loss; None means "no cameras drawn yet".
        self.cameras = None

        self.vgg_loss = VGGLoss().to(device)
        self.tv_loss = TVLoss(p=2).to(device)

    def _sample_cameras(self):
        """Draw ``num_views_per_iteration`` cameras uniformly from the configured ranges."""
        count = self.num_views_per_iteration
        # Sampling order (elev, azim, dist) is kept fixed so seeded runs stay reproducible.
        elev = torch.empty(count, device=self.device).uniform_(*self.elev_range)
        azim = torch.empty(count, device=self.device).uniform_(*self.azim_range)
        dist = torch.empty(count, device=self.device).uniform_(*self.dist_range)
        R, T = look_at_view_transform(dist=dist, elev=elev, azim=azim, device=self.device, at=((0, 0, 0),))
        return FoVPerspectiveCameras(R=R, T=T, zfar=3, device=self.device)

    def compute_loss(self, orig, proc, f, save_img=False):
        """Render both meshes from the current cameras and average the image losses.

        Args:
            orig: Mesh passed to the renderer as the reference.
            proc: Mesh passed to the renderer as the processed version.
            f: Mode flag. ``'c'`` samples a fresh set of cameras before
                rendering; any other value reuses the cameras from the previous
                call (they are sampled on demand if none exist yet). The flag
                also selects which files are written when ``save_img`` is set.
            save_img: When true, the first view is written to disk:
                ``img_original.png`` and ``img_compressed.png`` for ``f == 'c'``,
                ``img_downsampled.png`` for ``f == 'd'``.

        Returns:
            Dict with keys ``'l2'``, ``'ssim'``, ``'vgg'`` and ``'tvl'``, each a
            scalar tensor averaged over ``num_views_per_iteration`` views.
        """
        if f == 'c' or self.cameras is None:
            self.cameras = self._sample_cameras()

        num_views = self.num_views_per_iteration
        totals = {
            name: torch.tensor(0.0, device=self.device)
            for name in ('l2', 'ssim', 'vgg', 'tvl')
        }

        for j in range(num_views):
            camera = self.cameras[j]
            orig_render = self.renderer_textured(orig, cameras=camera)
            proc_render = self.renderer_textured(proc, cameras=camera)

            # The fourth channel of the reference render is used as a coverage
            # mask (presumably alpha — confirm against the shader output).
            orig_rgb = orig_render[:, :, :, :3]
            proc_rgb = proc_render[:, :, :, :3]
            visible = (orig_render[:, :, :, 3:] != 0).expand_as(orig_rgb)

            # L2 is restricted to covered pixels. A view with no coverage
            # contributes zero instead of the NaN an empty mean would produce.
            if visible.any():
                diff = orig_rgb[visible] - proc_rgb[visible]
                totals['l2'] = totals['l2'] + (diff ** 2).mean()

            # The remaining losses take channels-first images.
            orig_img = orig_rgb.permute(0, 3, 1, 2)
            proc_img = proc_rgb.permute(0, 3, 1, 2)
            totals['ssim'] = totals['ssim'] + (1. - ssim(orig_img, proc_img, data_range=1))
            totals['vgg'] = totals['vgg'] + self.vgg_loss(orig_img, proc_img)
            # NOTE(review): total variation is taken on the reference render,
            # not the processed one — confirm this is intended.
            totals['tvl'] = totals['tvl'] + self.tv_loss(orig_img)

            if save_img and j == 0:
                if f == 'c':
                    save_image(orig_img, "img_original.png")
                    save_image(proc_img, "img_compressed.png")
                elif f == 'd':
                    save_image(proc_img, "img_downsampled.png")

        return {name: total / num_views for name, total in totals.items()}

    def validate(self, orig, proc):
        """Compute mean PSNR and SSIM over a fixed 5 x 6 grid of viewpoints.

        The grid spans ``elev_range`` (5 steps) and ``azim_range`` (6 steps) at
        distance 1. One extra pair of images is rendered from the midpoint of
        both ranges and returned alongside the metrics.

        Args:
            orig: Mesh passed to the renderer as the reference.
            proc: Mesh passed to the renderer as the processed version.

        Returns:
            Tuple ``(psnr, ssim, orig_imgs, proc_imgs)``: PSNR rounded to 3
            decimals, SSIM rounded to 4 decimals, and the two channels-first
            midpoint renderings. PSNR is infinite for views whose renderings
            are identical.
        """
        elev = torch.linspace(*self.elev_range, 5)
        azim = torch.linspace(*self.azim_range, 6)
        elev_grid, azim_grid = torch.meshgrid((elev, azim), indexing='ij')
        R, T = look_at_view_transform(dist=1, elev=elev_grid.flatten(), azim=azim_grid.flatten(), device=self.device)
        cameras = FoVPerspectiveCameras(R=R, T=T, zfar=3, device=self.device)

        num_views = len(cameras)
        psnr_total, ssim_total = 0., 0.
        # Only Python floats leave this loop, so no autograd graph is needed.
        with torch.no_grad():
            for i in range(num_views):
                camera = cameras[i]
                orig_rgb = self.renderer_textured(orig, cameras=camera)[:, :, :, :3]
                proc_rgb = self.renderer_textured(proc, cameras=camera)[:, :, :, :3]
                mse = ((orig_rgb - proc_rgb) ** 2).mean()
                psnr_total += (10 * torch.log10(1.0 / mse) / num_views).item()
                view_ssim = ssim(orig_rgb.permute(0, 3, 1, 2), proc_rgb.permute(0, 3, 1, 2), data_range=1)
                ssim_total += (view_ssim / num_views).item()

        # Preview renderings from the centre of the configured view ranges.
        mid_elev = sum(self.elev_range) * 0.5
        mid_azim = sum(self.azim_range) * 0.5
        R, T = look_at_view_transform(dist=1, elev=mid_elev, azim=mid_azim, device=self.device)
        camera = FoVPerspectiveCameras(R=R, T=T, zfar=3, device=self.device)
        orig_imgs = self.renderer_textured(orig, cameras=camera)[:, :, :, :3].permute(0, 3, 1, 2)
        proc_imgs = self.renderer_textured(proc, cameras=camera)[:, :, :, :3].permute(0, 3, 1, 2)
        return round(psnr_total, 3), round(ssim_total, 4), orig_imgs, proc_imgs

class PrefitLoss:
    """Mesh regularisation terms computed with the PyTorch3D mesh losses."""

    def __init__(self) -> None:
        pass

    def compute_loss(self, mesh):
        """Evaluate the edge, normal-consistency and Laplacian losses for ``mesh``.

        Args:
            mesh: Mesh object forwarded unchanged to each PyTorch3D loss function.

        Returns:
            Dict with keys ``"edge"``, ``"normal"`` and ``"laplacian"`` holding
            the value returned by the corresponding loss function.
        """
        # Each term is assigned directly: the dict starts empty, so an
        # augmented assignment would raise KeyError on the first key.
        return {
            "edge": mesh_edge_loss(mesh),
            "normal": mesh_normal_consistency(mesh),
            "laplacian": mesh_laplacian_smoothing(mesh, method="uniform"),
        }
