import logging
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.nn import Module
from typing import Literal, Optional

# NOTE(review): basicConfig at import time configures the *root* logger for the
# whole process (it is a no-op if the root logger already has handlers).
# Consider moving this to the application entry point.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)

# Logger for this module, named after the module.
logger = logging.getLogger(__name__)

# Preprocessing applied to every image before it is fed to DINOv2:
#   1. PIL image -> float tensor (values in [0, 1] for 8-bit images),
#   2. resize so the shorter side is 244 px,
#   3. center-crop to 224x224,
#   4. map each channel via (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
# NOTE(review): the reference DINOv2/ImageNet eval pipeline resizes to 256 and
# normalizes with ImageNet mean/std; 244 and (0.5, 0.5) may be unintentional.
# Changing them alters every embedding, so confirm before touching.
_transform_image = T.Compose(
    transforms=[
        T.ToTensor(),
        T.Resize(size=244),
        T.CenterCrop(size=224),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)


def _load_image(image_file: str) -> torch.Tensor:
    """Load an image file and preprocess it into a single-image batch.

    The image is converted to 3-channel RGB before preprocessing. As a result,
    grayscale ("L"), palette ("P"), RGBA and CMYK inputs all yield a
    3-channel tensor. Slicing channels after the fact only works for RGB/RGBA.

    Args:
        image_file: Path to an image file readable by PIL.

    Returns:
        The output of ``_transform_image`` with a leading batch dimension of
        size 1, i.e. a tensor of shape (1, 3, 224, 224).
    """
    # The transform must run inside the ``with`` block: PIL loads pixel data
    # lazily and the file handle is closed on exit.
    with Image.open(image_file) as image:
        rgb_image = image.convert("RGB")
        transformed_image = _transform_image(rgb_image).unsqueeze(0)  # type: ignore
    return transformed_image


# GitHub repository that torch.hub loads the DINOv2 entry points from.
_hub_prefix = "facebookresearch/dinov2"

# Accepted values for the user-facing ``model_size`` option.
_ModelSize = Literal["small", "base", "large"]

# model_size -> (torch.hub entry-point name, embedding width).
# The width is what DinoV2Embedder exposes as ``embedding_dim``.
_dino_model_info: dict[_ModelSize, tuple[str, int]] = {
    "small": ("dinov2_vits14", 384),
    "base": ("dinov2_vitb14", 768),
    "large": ("dinov2_vitl14", 1024),
}


class DinoV2Embedder:
    """Compute image embeddings with a DINOv2 model loaded through ``torch.hub``.

    Attributes:
        device: ``torch.device`` the model runs on.
        model_size: The requested size key (``"small"``, ``"base"`` or ``"large"``).
        embedding_dim: Embedding width listed for ``model_size`` in ``_dino_model_info``.
        dino_model: The loaded DINOv2 network, moved to ``device`` and set to eval mode.
    """

    def __init__(
        self,
        model_size: Literal["small", "base", "large"] = "base",
        device: Optional[Literal["cuda", "cpu"]] = None,
    ):
        """Load the DINOv2 variant for ``model_size`` onto ``device``.

        Args:
            model_size: Which DINOv2 variant to load (a key of ``_dino_model_info``).
            device: ``"cuda"`` or ``"cpu"``. When None, CUDA is used if it is
                available, otherwise CPU.

        Raises:
            KeyError: If ``model_size`` is not a key of ``_dino_model_info``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model_size = model_size
        hub_id, self.embedding_dim = _dino_model_info[self.model_size]
        # torch.hub.load fetches the hub repo (and weights) and may need
        # network access the first time it runs.
        self.dino_model: Module = torch.hub.load(_hub_prefix, hub_id)  # type: ignore
        self.dino_model.to(self.device)
        # Inference only: switch layers such as dropout to evaluation behaviour.
        self.dino_model.eval()

    def compute_single_embedding(self, image_file: str) -> list:
        """Embed a single image file.

        Args:
            image_file: Path to an image file.

        Returns:
            The model output for the image, flattened into a single-row nested
            list (shape ``(1, D)``). ``D`` is presumably ``embedding_dim``;
            this is not checked here.
        """
        with torch.no_grad():
            batch = _load_image(image_file).to(self.device)
            embedding = self.dino_model(batch)
        # ``embedding[0]`` is the only item in the batch. ``.numpy()`` already
        # yields an ndarray, so no extra ``np.array`` copy is needed. The
        # single-row 2-D list shape is kept because callers receive it.
        return embedding[0].cpu().numpy().reshape(1, -1).tolist()

    def compute_embeddings(self, image_files: list[str]) -> dict[str, list]:
        """Embed each image file, one at a time.

        Args:
            image_files: Paths of the images to embed.

        Returns:
            Mapping from each path to its ``compute_single_embedding`` result.
            Duplicate paths collapse to a single key.
        """
        return {
            image_file: self.compute_single_embedding(image_file)
            for image_file in image_files
        }
