import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from albumentations.pytorch import ToTensorV2
import albumentations as A
from PIL import Image
import os
import cv2
import numpy as np
from tqdm import tqdm
import timm
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import argparse
def _parse_bool_flag(value):
    """Interpret a ``--verbose`` value: 0/false/no/off/empty disable logging, anything else enables it."""
    return value.strip().lower() not in {"", "0", "false", "f", "no", "n", "off"}


parser = argparse.ArgumentParser()

# required=True: without it a missing --input_dir only surfaced later as a TypeError
# from os.path.join(None, ...).
parser.add_argument("--input_dir", required=True, help="input directory !! ")
# Previously any supplied string (including "False" or "0") was truthy and enabled logging.
# Now a bare `--verbose` means True, an explicit value is parsed by _parse_bool_flag,
# and the default remains falsy.
parser.add_argument(
    "--verbose",
    nargs="?",
    const=True,
    default=False,
    type=_parse_bool_flag,
    help="verbose to print logs for debug !! ",
)
args = parser.parse_args()

verbose = args.verbose
source_dir = args.input_dir


device = "cuda" if torch.cuda.is_available() else "cpu"
# Evaluation preprocessing: resize to 224x224, normalise with fixed per-channel mean/std
# (values matching the common ImageNet normalisation), then convert to a torch tensor.
test_transform = A.Compose(
    [
        A.Resize(height=224, width=224, p=1),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

def apply_transform(image, transform):
    """Apply an albumentations-style ``transform`` to ``image``.

    ``image`` (e.g. a PIL image) is first converted with ``np.array``. The transform is
    called as ``transform(image=...)`` and the value under its ``'image'`` key is returned.
    """
    as_array = np.array(image)
    transformed = transform(image=as_array)
    return transformed['image']


class PyramidViTEfficientNetClassifier(nn.Module):
    """Classifier that fuses a Pyramid ViT branch and an EfficientNet branch.

    Both backbones are created with ``timm`` and have their classification heads removed
    via ``reset_classifier(0)``. Their outputs are concatenated along dim 1 and passed
    through a small head: Linear -> ReLU -> Dropout(0.5) -> Linear, producing
    ``num_classes`` outputs.

    Args:
        vit_model_name: timm model name for the transformer branch.
        effnet_model_name: timm model name for the convolutional branch.
        num_classes: number of output units of the final Linear layer.
        pretrained: forwarded to ``timm.create_model`` for both backbones.
    """

    def __init__(self, vit_model_name="pvt_v2_b5", effnet_model_name="efficientnet_b0", num_classes=1, pretrained=False):
        super().__init__()

        # Backbones are created and registered in this order (transformer first); the
        # attribute names are part of the state_dict keys used by saved checkpoints.
        self.vit_model, vit_width = self._headless_backbone(vit_model_name, pretrained)
        self.effnet_model, effnet_width = self._headless_backbone(effnet_model_name, pretrained)

        # Fusion head. Layer positions 0..3 determine the "classifier.<idx>.*" state_dict
        # keys, so the sequence must stay exactly as is.
        hidden_width = 64
        self.classifier = nn.Sequential(
            nn.Linear(vit_width + effnet_width, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_width, num_classes),
        )

    @staticmethod
    def _headless_backbone(model_name, pretrained):
        """Create a timm model, strip its classifier, and return ``(model, feature_width)``."""
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Read the head's input width before the head is removed.
        feature_width = backbone.get_classifier().in_features
        backbone.reset_classifier(0)
        return backbone, feature_width

    def forward(self, x):
        """Run both backbones on ``x``, concatenate their outputs on dim 1, and classify."""
        fused = torch.cat(
            (self.vit_model(x), self.effnet_model(x)),
            dim=1,
        )
        return self.classifier(fused)
    
model = PyramidViTEfficientNetClassifier()

# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict = torch.load("FaceQ_transformer.pt", map_location=device)
model.load_state_dict(state_dict)
# Put the model on the selected device and switch to inference mode. Without eval(), the
# Dropout(0.5) in the head (and any train-mode layers in the backbones) stays active,
# so predictions are stochastic and wrong.
model.to(device)
model.eval()

IMAGE_EXTENSIONS = (".jpg", ".png")


def _list_images(directory):
    """Return sorted paths of the .jpg/.png files (extension matched case-insensitively) in ``directory``."""
    paths = (os.path.join(directory, name) for name in os.listdir(directory))
    return sorted(
        path for path in paths
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in IMAGE_EXTENSIONS
    )


root_directory = source_dir
good_directory = os.path.join(root_directory, "good_crops")
bad_directory = os.path.join(root_directory, "bad_crops")
good_images_filepaths = _list_images(good_directory)
bad_images_filepaths = _list_images(bad_directory)
val_images_filepaths = good_images_filepaths + bad_images_filepaths

# Label by source folder (1 = good_crops, 0 = bad_crops). The former `"bad" in path` test
# also matched "bad" anywhere else in the path (root directory, file names) and mislabeled.
labels = [1] * len(good_images_filepaths) + [0] * len(bad_images_filepaths)
if not val_images_filepaths:
    raise SystemExit(f"No .jpg/.png images found in {good_directory} or {bad_directory}")

results = []

for image_path in tqdm(val_images_filepaths):
    # Context manager closes the file handle; convert() returns an independent copy.
    with Image.open(image_path) as pil_image:
        image = apply_transform(pil_image.convert('RGB'), test_transform)
    image = image.unsqueeze(0).to(device)  # add batch dimension, place on the model's device

    with torch.no_grad():
        output = model(image)
        prob = torch.sigmoid(output).item()  # sigmoid of the single logit -> probability

    if verbose:
        print(f'Inference result for {image_path}: {prob}')
    results.append(prob)

threshold = 0.6
results_bin = [1 if i >= threshold else 0 for i in results]
print( "Accuracy Score : ", accuracy_score(labels, results_bin) )
print( "F1 Score : ", f1_score(labels, results_bin))
# roc_auc_score raises ValueError when y_true contains a single class.
if len(set(labels)) > 1:
    print( "AUC ROC : ", roc_auc_score(labels, results))
else:
    print("AUC ROC : undefined (only one class present in the labels)")
print("Confusion Matrix : \n")
# Fix the label set so the matrix is always 2x2 (rows/cols: 0 = bad, 1 = good).
print(confusion_matrix(labels, results_bin, labels=[0, 1]))