import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.structures import Meshes
from pytorch3d.ops import GraphConv, interpolate_face_attributes
from pytorch3d.renderer import TexturesUV
from utils import sample_texture, texture_to_fragments, warp_uvs, get_identity_warp, smoothen_flow
from math import log2

class ResBlock(nn.Module):
    """Residual block made of two 3x3 convolutions over a 2-D feature map.

    The main path is ``conv1 -> ReLU -> conv2``; its result is added to a
    shortcut branch and passed through a final ReLU.

    Args:
        in_ch: Number of channels of the incoming feature map.
        out_ch: Number of channels produced by both convolutions.
        downsample: Optional module applied to the input to build the
            shortcut branch. When it is given, ``conv1`` uses stride 2;
            otherwise ``conv1`` uses stride 1 and the input itself is the
            shortcut.
    """

    def __init__(self, in_ch:int, out_ch:int, downsample:any=None) -> None:
        super().__init__()
        # A shortcut module implies a strided first convolution on the main path.
        first_stride = 2 if downsample is not None else 1
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=first_stride, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.downsample = downsample

    def forward(self, input):
        """Return ``ReLU(main_path(input) + shortcut(input))``."""
        hidden = F.relu(self.conv1(input))
        hidden = self.conv2(hidden)
        shortcut = input if self.downsample is None else self.downsample(input)
        return F.relu(hidden + shortcut)

class GraphResBlock(nn.Module):
    """Residual block built from two ``GraphConv`` layers over vertex features.

    Args:
        in_ch: Feature width of the incoming per-vertex features.
        out_ch: Feature width produced by both graph convolutions.

    The identity skip connection is only used when ``in_ch == out_ch``;
    otherwise the block is a plain two-layer graph convolution.
    """

    def __init__(self, in_ch:int, out_ch:int) -> None:
        super().__init__()
        self.conv1 = GraphConv(in_ch, out_ch)
        self.conv2 = GraphConv(out_ch, out_ch)
        # The input can only be added back when the feature widths agree.
        self.add_residual = in_ch == out_ch

    def forward(self, verts: torch.Tensor, edges:torch.Tensor):
        """Apply both graph convolutions and, if enabled, the skip connection.

        Args:
            verts: Per-vertex features fed to the first ``GraphConv``.
            edges: Edge list passed unchanged to both ``GraphConv`` layers.

        Returns:
            The ReLU-activated output features.
        """
        hidden = self.conv2(F.relu(self.conv1(verts, edges)), edges)
        if self.add_residual:
            hidden = hidden + verts
        return F.relu(hidden)

class TextureEncoderNet(nn.Module):
    """Convolutional encoder that repeatedly halves a 6-channel input map.

    A stem (3x3 conv -> batch norm -> ReLU) lifts the input to 64 channels.
    It is followed by ``scale`` stages, each made of two ``ResBlock``s and a
    2x2 max-pool with stride 2.

    Args:
        scale: Number of residual-stage + pooling steps to apply (the caller
            in this file passes ``int(log2(scale_factor))``).
    """

    def __init__(self, scale) -> None:
        super().__init__()
        stem = [
            nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.layer0 = nn.Sequential(*stem)
        stages = [self._make_layer(64, 64, 2) for _ in range(scale)]
        self.rbs = nn.ModuleList(stages)
        self.scale = scale

    def _make_layer(self, in_ch:int, out_ch:int, block_count:int):
        """Build a sequential stack of ``ResBlock``s.

        The first block maps ``in_ch -> out_ch``; when the widths differ it
        receives a strided 1x1 conv + batch norm as its shortcut. The
        remaining ``block_count - 1`` blocks keep ``out_ch`` channels.
        """
        shortcut = None
        if in_ch != out_ch:
            shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_ch),
            )
        head = ResBlock(in_ch, out_ch, shortcut)
        tail = [ResBlock(out_ch, out_ch) for _ in range(block_count - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, texture_map: torch.Tensor) -> torch.Tensor:
        """Encode ``texture_map`` and return the pooled 64-channel feature map."""
        feats = self.layer0(texture_map)
        for stage_idx in range(self.scale):
            feats = F.max_pool2d(self.rbs[stage_idx](feats), kernel_size=2, stride=2)
        return feats

class GraphEncoderNet(nn.Module):
    """Stack of four 64-channel ``GraphResBlock``s applied one after another."""

    def __init__(self) -> None:
        super().__init__()
        # Attribute names are kept as grb1..grb4 so state_dict keys stay stable.
        self.grb1 = GraphResBlock(64, 64)
        self.grb2 = GraphResBlock(64, 64)
        self.grb3 = GraphResBlock(64, 64)
        self.grb4 = GraphResBlock(64, 64)

    def forward(self, vert, edges):
        """Run the vertex features through all four blocks using the same edges.

        Args:
            vert: Per-vertex features (64 channels, matching the blocks' width).
            edges: Edge list forwarded unchanged to every block.

        Returns:
            The vertex features produced by the last block.
        """
        for block in (self.grb1, self.grb2, self.grb3, self.grb4):
            vert = block(vert, edges)
        return vert

class TextureDecoderNet(nn.Module):
    """Decode a feature map into a 3-channel map at the same spatial size.

    Pipeline: 3x3 conv (``in_ch -> 32``) -> ReLU -> two ``ResBlock``s ->
    3x3 conv (``32 -> 3``). Both plain convolutions use ``'same'`` padding.
    No activation or clamping is applied to the output here.

    Args:
        in_ch: Number of channels of the incoming feature map.
    """

    def __init__(self, in_ch:int) -> None:
        super().__init__()
        self.conv_1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding='same')
        self.block2 = self._make_layer(32, 32, 2)
        self.conv_3 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding='same')

    def _make_layer(self, in_ch:int, out_ch:int, block_count:int):
        """Build a sequential stack of ``ResBlock``s without shortcut modules.

        The first block maps ``in_ch -> out_ch``; it is always created. The
        remaining ``block_count - 1`` blocks keep ``out_ch`` channels.
        """
        input_widths = [in_ch] + [out_ch] * (block_count - 1)
        blocks = [ResBlock(width, out_ch) for width in input_widths]
        return nn.Sequential(*blocks)

    def forward(self,x) -> torch.Tensor:
        """Return the decoded 3-channel map for feature map ``x``."""
        hidden = F.relu(self.conv_1(x))
        hidden = self.block2(hidden)
        return self.conv_3(hidden)
    
class FlowDecoderNet(nn.Module):
    """Decode a 32-channel feature map into a 2-channel, channels-last field.

    Pipeline: 3x3 conv (``32 -> 32``) -> ReLU -> two ``ResBlock``s ->
    3x3 conv (``32 -> 2``) -> tanh. The result is permuted from
    ``[N, 2, H, W]`` to ``[N, H, W, 2]``; tanh keeps every value in (-1, 1).
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding='same')
        self.block2 = self._make_layer(32, 32, 2)
        self.conv_3 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding='same')

    def _make_layer(self, in_ch:int, out_ch:int, block_count:int):
        """Build a sequential stack of ``ResBlock``s without shortcut modules.

        The first block maps ``in_ch -> out_ch``; it is always created. The
        remaining ``block_count - 1`` blocks keep ``out_ch`` channels.
        """
        input_widths = [in_ch] + [out_ch] * (block_count - 1)
        blocks = [ResBlock(width, out_ch) for width in input_widths]
        return nn.Sequential(*blocks)

    def forward(self,x) -> torch.Tensor:
        """Return the tanh-bounded 2-channel field for ``x`` in channels-last layout."""
        hidden = self.block2(F.relu(self.conv_1(x)))
        flow = F.tanh(self.conv_3(hidden))
        # Move the two flow components to the last axis: [N,2,H,W] -> [N,H,W,2].
        return flow.permute(0, 2, 3, 1)

class TextureCodec(nn.Module):
    """Encode a mesh's UV texture at reduced resolution and decode it again.

    The module is bound to a single textured mesh at construction time. Its
    ``forward`` takes no arguments: it encodes the stored texture together
    with an interpolated normal map, optionally refines the features on the
    mesh graph, optionally predicts a UV warp, and decodes a low-resolution
    texture. For comparison it also builds a bicubic-downsampled copy of the
    original texture.

    Args:
        mesh: Textured mesh whose ``textures`` provide ``maps_padded``,
            ``verts_uvs_padded``, ``faces_uvs_padded`` and ``align_corners``.
            The shape comments below assume a batch of one mesh.
        scale: Downscaling factor between the original and the coded texture.
        use_gcnn: If true, encoded features are sampled onto the vertices,
            passed through ``GraphEncoderNet`` and added back to the encoded
            map.
        use_uv_warp: If true, half of the feature channels drive a
            ``FlowDecoderNet`` whose output warps both the remaining features
            and the vertex UVs.

    NOTE(review): all mesh-derived tensors are stored as plain attributes,
    not registered buffers, so ``module.to(device)`` does not move them; the
    mesh presumably has to be on the target device already - confirm.
    """

    def __init__(self, mesh: Meshes, scale: float, use_gcnn: bool, use_uv_warp: bool):
        super(TextureCodec, self).__init__()
        self.scale = scale
        self.use_gcnn = use_gcnn
        self.use_uv_warp = use_uv_warp
        # The encoder halves the resolution once per stage, so it gets
        # log2(scale) stages.
        # NOTE(review): a non-power-of-two scale is truncated here but not in
        # size_lr below, so the two resolutions would disagree - confirm that
        # callers only pass powers of two.
        log_scale = int(log2(scale))

        self.texture_encoder = TextureEncoderNet(log_scale)
        if self.use_gcnn:
            self.graph_encoder = GraphEncoderNet()
        decoder_in_ch = 64
        if self.use_uv_warp:
            self.flow_decoder = FlowDecoderNet()
            # forward() splits the 64 feature channels in two; only the second
            # half reaches the texture decoder.
            decoder_in_ch = 32
        self.texture_decoder = TextureDecoderNet(in_ch = decoder_in_ch)

        self.mesh = mesh
        self.verts = mesh.verts_packed() #[V,3]
        self.faces = mesh.faces_packed() #[F,3]
        self.edges = mesh.edges_packed() #[E,2]
        self.texture_map = mesh.textures.maps_padded() #[1,H,W,3]
        self.align_corners = mesh.textures.align_corners
        _,h,w,_ = self.texture_map.size()
        self.verts_uvs = self.mesh.textures.verts_uvs_padded() #[1,UV,2]  <-|
        self.face_uv_idxs = self.mesh.textures.faces_uvs_padded() #[1,K,3] _|

        # Both sizes are stored in (width, height) order as integer tensors on
        # the texture's device.
        map_device = self.texture_map.device
        self.size_hr = torch.tensor([w, h], dtype=torch.long, device=map_device)
        self.size_lr = torch.tensor([round(w/scale), round(h/scale)], dtype=torch.long, device=map_device)

        self.identity_warp = get_identity_warp(self.size_lr, device=self.size_lr.device)

        # The returned pairs are used as the (pix_to_face, barycentric_coords)
        # arguments of interpolate_face_attributes, at high and low resolution.
        self.p2f_hr, self.bcs_hr = texture_to_fragments(self.size_hr, mesh.textures)
        self.p2f_lr, self.bcs_lr = texture_to_fragments(self.size_lr, mesh.textures)

        # Rasterise the vertex normals into a high-resolution normal map that
        # is fed to the encoder next to the texture.
        normals = mesh.verts_normals_packed() #[V,3]
        facewise_normals = normals[self.faces.flatten(), :].view(self.faces.size(0),3,-1) #[K,3,C]
        self.normal_map_hr = interpolate_face_attributes(self.p2f_hr, self.bcs_hr, facewise_normals).squeeze(3).permute(0,3,1,2)

    def forward(self):
        """Run the codec on the stored mesh.

        Returns:
            A tuple ``(out_mesh, dsp_mesh)`` of ``Meshes`` sharing the stored
            vertices and faces. ``out_mesh`` carries the decoded texture
            (clamped to [0, 1]) with the warped UVs when ``use_uv_warp`` is
            set, otherwise the original UVs. ``dsp_mesh`` carries the
            bicubic-downsampled original texture (clamped to [0, 1]) with the
            original UVs.
        """
        texture = self.texture_map.permute(0,3,1,2) #[1,3,H,W]
        # Texture and normal map are stacked along the channel axis.
        encoder_input = torch.concat([texture, self.normal_map_hr], dim=1)

        encoded_texture_map = self.texture_encoder(encoder_input) #[1,64,H/scale,W/scale] for power-of-two scale

        if self.use_gcnn:
            feats_on_verts = sample_texture(encoded_texture_map, mesh=self.mesh) #[V,C]
            feats_on_verts = self.graph_encoder(feats_on_verts, self.edges) #[V,C]
            feats_on_faces = feats_on_verts[self.faces.flatten(), :].view(self.faces.size(0),3,-1) #[K,3,C]
            geocoded_texture_map = interpolate_face_attributes(self.p2f_lr, self.bcs_lr, feats_on_faces) #[N,H,W,1,C]
            geocoded_texture_map = geocoded_texture_map.squeeze(3).permute(0,3,1,2) #[N,C,H,W]
            # Graph-refined features act as a residual on top of the encoder output.
            geocoded_texture_map += encoded_texture_map
        else:
            geocoded_texture_map = encoded_texture_map

        if self.use_uv_warp:
            C = geocoded_texture_map.size(1)
            # First half of the channels predicts the flow; the second half is
            # what gets warped and decoded.
            flow_in, geocoded_texture_map = torch.split(geocoded_texture_map, split_size_or_sections=C//2, dim=1)
            residual_warping = self.flow_decoder(flow_in)
            uv_warping = torch.clamp(self.identity_warp + residual_warping, -1.0, 1.0) #[1,H,W,2]
            warped_uvs = warp_uvs(self.verts_uvs, uv_warping, align_corners=self.align_corners)
            # NOTE: a smoothen_flow(warped_uvs, face_uv_idxs, p2f_lr, bcs_lr)
            # pass over the warp was tried here and is currently disabled.
            # NOTE(review): grid_sample runs with its default align_corners
            # while warp_uvs receives self.align_corners - confirm they agree.
            geocoded_texture_map = F.grid_sample(geocoded_texture_map, uv_warping) #warped for better uv parametrization
        else:
            warped_uvs = self.verts_uvs

        decoded_texture_map = self.texture_decoder(geocoded_texture_map)
        decoded_texture_map = torch.clamp(decoded_texture_map, 0., 1.)

        # size_lr is stored as (width, height), but F.interpolate expects the
        # output size as (height, width); passing it unswapped transposes the
        # target size for non-square textures.
        lr_w, lr_h = self.size_lr.tolist()
        downsampled_texture_map = F.interpolate(texture, size=(lr_h, lr_w), mode='bicubic', antialias=False, align_corners=True)
        # Bicubic interpolation can overshoot the input range.
        downsampled_texture_map = torch.clamp(downsampled_texture_map, 0., 1.)

        dsp_texture = TexturesUV(downsampled_texture_map.permute(0,2,3,1), self.face_uv_idxs, self.verts_uvs)
        out_texture = TexturesUV(decoded_texture_map.permute(0,2,3,1), self.face_uv_idxs, warped_uvs)
        out_mesh = Meshes([self.verts], [self.faces], out_texture)
        dsp_mesh = Meshes([self.verts], [self.faces], dsp_texture)
        return out_mesh, dsp_mesh