import os
import argparse
import glob
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import open_clip
import numpy as np

class ImageEmbeddingProcessor:
    """Computes image embeddings with a model and visualizes their PCA projections."""

    # Per-channel normalization constants used to undo the image preprocessing
    # (the CLIP/OpenAI values previously hard-coded in inverse_normalize).
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, model, device, preprocess):
        """
        Initializes the processor with a model, device and preprocessing transform.

        Args:
            model: The model used for generating image embeddings; must expose
                ``encode_image(batch)``.
            device: The computing device ('cuda' or 'cpu') that preprocessed
                images are moved to before encoding.
            preprocess: Callable mapping a PIL image to a tensor accepted by
                ``model.encode_image`` (after a batch dimension is added).
        """
        self.model = model
        self.device = device
        self.preprocess = preprocess
        # Most recent result of load_and_embed_images(); visualize_projections
        # falls back to it when no embeddings_map is passed explicitly.
        self._last_embeddings_map = None

    def inverse_normalize(self, image_tensor, mean=None, std=None):
        """Undo per-channel normalization: returns ``image_tensor * std + mean``.

        Args:
            image_tensor: Float tensor whose channel axis is third from last,
                i.e. shaped ``(C, H, W)`` or ``(N, C, H, W)``.
            mean: Per-channel means; defaults to ``CLIP_MEAN``.
            std: Per-channel standard deviations; defaults to ``CLIP_STD``.

        Returns:
            A new tensor of the same shape with the normalization reversed.
        """
        mean = self.CLIP_MEAN if mean is None else mean
        std = self.CLIP_STD if std is None else std
        # Shape (C, 1, 1) so the constants broadcast over H, W and any batch dim.
        mean_t = torch.as_tensor(mean, dtype=image_tensor.dtype, device=image_tensor.device).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=image_tensor.dtype, device=image_tensor.device).view(-1, 1, 1)
        return image_tensor * std_t + mean_t

    def load_and_embed_images(self, image_paths):
        """Load images from disk and compute one embedding per image.

        Args:
            image_paths: Iterable of image file paths; order determines the keys.

        Returns:
            Dict mapping ``'img{idx}'`` (idx = position in ``image_paths``) to a
            ``(processed_image, embedding)`` tuple, where ``processed_image`` is
            the preprocessed tensor with a leading batch dimension, on
            ``self.device``, and ``embedding`` is ``model.encode_image`` of it.
            The map is also cached for ``visualize_projections``.
        """
        embeddings_map = {}
        for idx, path in enumerate(image_paths):
            # Context manager closes the file handle once preprocessing has
            # forced the pixel data to load.
            with Image.open(path) as image:
                processed_image = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                embedding = self.model.encode_image(processed_image)
            embeddings_map[f'img{idx}'] = (processed_image, embedding)
        self._last_embeddings_map = embeddings_map
        return embeddings_map

    def project_embeddings(self, embeddings_map, n_components=6):
        """Project embeddings onto a lower-dimensional space using PCA.

        Args:
            embeddings_map: Mapping as returned by ``load_and_embed_images``.
            n_components: Requested number of principal components. It is
                clamped to ``min(n_samples, n_features)``, the largest value PCA
                accepts, so fewer images than components still works.

        Returns:
            ``(projections, pca)``: an ``(n_samples, k)`` array in the map's
            iteration order, and the fitted ``PCA`` instance.

        Raises:
            ValueError: If ``embeddings_map`` is empty.
        """
        if not embeddings_map:
            raise ValueError("embeddings_map is empty; nothing to project")
        embeddings = np.stack(
            [embedding.detach().float().cpu().numpy().ravel() for _, embedding in embeddings_map.values()]
        )
        n_components = min(n_components, *embeddings.shape)
        pca = PCA(n_components=n_components)
        projections = pca.fit_transform(embeddings)
        return projections, pca

    def visualize_projections(self, projections, img_title_map, mat_org, mat_final, texts,
                              output_path, output_filename, embeddings_map=None):
        """Plot images with their PCA projections plus two vision-text matrices; save as PDF.

        Layout: each grid row holds up to three (image, projection bar chart)
        pairs in columns 0-5; column 6 shows ``mat_org`` (row 0) and
        ``mat_final`` (row 1). Rows are added when there are more than six images.

        Args:
            projections: Sequence of 1-D projection vectors, in the same order
                as ``img_title_map``'s keys (only keys paired with a projection
                are drawn).
            img_title_map: Ordered mapping from embedding key (``'img0'``, ...)
                to the panel title.
            mat_org: Matrix drawn with ``matshow`` under the title 'Original'.
            mat_final: Matrix drawn with ``matshow`` under the title 'Final'.
            texts: Tick labels for both matrix axes; the matrices are expected
                to be ``len(texts) x len(texts)``.
            output_path: Directory to write into (created if missing).
            output_filename: File name without extension; ``.pdf`` is appended.
            embeddings_map: Mapping as returned by ``load_and_embed_images``.
                Defaults to the map from the most recent call to that method.

        Returns:
            Path of the saved PDF.

        Raises:
            ValueError: If no embeddings map is available, or a key in
                ``img_title_map`` that is paired with a projection is missing
                from it.
        """
        if embeddings_map is None:
            embeddings_map = getattr(self, "_last_embeddings_map", None)
        if embeddings_map is None:
            raise ValueError(
                "No embeddings available: pass embeddings_map or call "
                "load_and_embed_images() first."
            )

        panels = list(zip(img_title_map.keys(), projections))
        missing = [key for key, _ in panels if key not in embeddings_map]
        if missing:
            raise ValueError(f"img_title_map keys not found in embeddings_map: {missing}")

        n_texts = len(texts)
        pairs_per_row = 3
        # At least two rows are required for the two matrices in the last column.
        nrows = max(2, -(-len(panels) // pairs_per_row))
        plt.rcParams.update({'font.size': 12})
        fig, axs = plt.subplots(nrows=nrows, ncols=2 * pairs_per_row + 1, squeeze=False,
                                figsize=(24, 4 * nrows))

        for i, (key, projection) in enumerate(panels):
            row, col = divmod(i, pairs_per_row)
            img_ax = axs[row, col * 2]
            bar_ax = axs[row, col * 2 + 1]

            image = self.inverse_normalize(embeddings_map[key][0]).squeeze(0)
            # Inverse normalization can overshoot [0, 1] slightly; clamp so
            # imshow neither warns nor clips on its own.
            img_ax.imshow(image.clamp(0, 1).permute(1, 2, 0).cpu().numpy())
            img_ax.set_title(img_title_map[key])
            img_ax.axis('off')

            bar_ax.bar(range(len(projection)), projection, color='skyblue')
            bar_ax.set_title(f'Projection {key}')
            bar_ax.set_xticks(range(len(projection)))
            bar_ax.set_ylim([-10, 10])

        # Hide image/bar slots that received no panel, and last-column cells
        # below the two matrices.
        for i in range(len(panels), nrows * pairs_per_row):
            row, col = divmod(i, pairs_per_row)
            axs[row, col * 2].axis('off')
            axs[row, col * 2 + 1].axis('off')
        for row in range(2, nrows):
            axs[row, -1].axis('off')

        # Plotting the original and final matrices
        for row, (title, matrix) in enumerate((('Original', mat_org), ('Final', mat_final))):
            ax = axs[row, -1]
            mappable = ax.matshow(matrix, cmap='viridis')
            ax.set_title(title)
            fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
            ax.set_xticks(range(n_texts))
            ax.set_yticks(range(n_texts))
            ax.set_xticklabels(texts, rotation=90)
            ax.set_yticklabels(texts)

        fig.tight_layout()
        os.makedirs(output_path, exist_ok=True)
        filename = os.path.join(output_path, f"{output_filename}.pdf")
        fig.savefig(filename)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        return filename

# Placeholder vision-text matrices (rows/columns follow ``texts`` in main()).
# These are the example values that previously lived in a comment. They are used
# only when --mat_org / --mat_final are not supplied; real matrices can be
# obtained with the ClassificationHelper class.
_EXAMPLE_MAT_ORG = np.array([[1.0000e+00, 5.8658e-10, 1.6041e-12],
                             [1.1187e-07, 1.0000e+00, 8.2994e-12],
                             [2.2586e-12, 8.6227e-12, 1.0000e+00]])
_EXAMPLE_MAT_FINAL = np.array([[4.2853e-11, 2.7194e-11, 1.0000e+00],
                               [3.7758e-12, 2.7308e-11, 1.0000e+00],
                               [1.0000e+00, 1.0218e-09, 9.7523e-12]])


def _load_matrix(path, default):
    """Return the array stored in the ``.npy`` file at *path*, or *default* when *path* is None."""
    # np.load keeps allow_pickle=False by default, so untrusted files cannot run code.
    return default if path is None else np.load(path)


def main():
    """CLI entry point: embed the images in a directory, project them with PCA and plot the result."""
    parser = argparse.ArgumentParser(description="Image Embedding Processing and Visualization")
    parser.add_argument("--image_dir", type=str, required=True, help="Directory containing image files")
    parser.add_argument("--output_path", type=str, required=True, help="Output path for visualizations")
    parser.add_argument("--pattern", type=str, default="*.jpeg",
                        help="Glob pattern selecting image files inside --image_dir")
    parser.add_argument("--mat_org", type=str, default=None,
                        help="Optional .npy file with the original vision-text matrix")
    parser.add_argument("--mat_final", type=str, default=None,
                        help="Optional .npy file with the final vision-text matrix")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k', device=device)
    # The open_clip README calls model.eval() after creation before running
    # inference, so dropout/stochastic-depth style layers are disabled.
    model.eval()

    processor = ImageEmbeddingProcessor(model, device, preprocess)
    # Sorting keeps the 'img{idx}' keys aligned with img_title_map below.
    image_paths = sorted(glob.glob(os.path.join(args.image_dir, args.pattern)))
    if not image_paths:
        parser.error(f"no files matching {args.pattern!r} found in {args.image_dir!r}")
    embeddings_map = processor.load_and_embed_images(image_paths)

    projections, _ = processor.project_embeddings(embeddings_map)

    mat_org = _load_matrix(args.mat_org, _EXAMPLE_MAT_ORG)
    mat_final = _load_matrix(args.mat_final, _EXAMPLE_MAT_FINAL)

    texts = ["lizard", "peacock", "wombat"]
    img_title_map = {'img0': 'agama lizard', 'img1': 'peacock', 'img2': 'wombat', 'img3': 'agama lizard -> wombat', 'img4': 'peacock -> wombat', 'img5': 'wombat -> agama lizard'}

    processor.visualize_projections(projections, img_title_map, mat_org, mat_final, texts, args.output_path, "proj")


if __name__ == "__main__":
    main()



