"""
Embedding encoding matrix generation module.
"""

import cv2
import numpy as np
import logging

logger = logging.getLogger(__name__)
from typing import Dict, Any


class EmbeddingEncoder:
    """
    Create 2D encoding matrices from CLIP model embeddings.

    Converts embedding arrays into visual matrix representations: each
    frame's embedding becomes one row, the whole matrix is min-max
    normalized to [0, 1], resized to ``(target_height, target_width)`` with
    OpenCV, and quantized to ``uint8`` for visualization.
    """

    def __init__(self, target_height: int = 400, target_width: int = 768):
        """
        Initialize embedding encoder.

        Args:
            target_height: Target height (rows) of the encoding image.
            target_width: Target width (columns) of the encoding image.

        Raises:
            ValueError: If either target dimension is not positive.
        """
        # Fail early: cv2.resize would otherwise raise an opaque cv2.error
        # only when the first matrix is encoded.
        if target_height <= 0 or target_width <= 0:
            raise ValueError(
                "target dimensions must be positive, got "
                f"{target_width}x{target_height}"
            )
        self.target_height = target_height
        self.target_width = target_width

        logger.info(
            "EmbeddingEncoder initialized: %sx%s", target_width, target_height
        )

    def create_encoding_matrix(self, embeddings: np.ndarray) -> Dict[str, Any]:
        """
        Create 2D encoding matrix from embeddings.

        Args:
            embeddings: Embeddings array of shape ``(n_frames, embedding_dim)``.
                A 1D array of shape ``(embedding_dim,)`` is treated as a
                single frame.

        Returns:
            Dictionary with keys:
                ``encoding_matrix``: the 2D embeddings matrix (one row per
                    frame, original dtype),
                ``normalized_matrix``: that matrix min-max scaled to [0, 1]
                    (all zeros when every value is identical),
                ``encoding_image``: ``uint8`` image of shape
                    ``(target_height, target_width)``,
                ``original_embeddings``: the input array,
                ``n_frames``, ``embedding_dim``: dimensions of
                    ``encoding_matrix``,
                ``encoding_shape``: shape of ``encoding_image``.

        Raises:
            ValueError: If ``embeddings`` is empty, is not 1D/2D, or contains
                NaN or infinite values.
        """
        embeddings = np.asarray(embeddings)
        logger.info("Creating encoding matrix from embeddings: %s", embeddings.shape)

        if embeddings.ndim == 1:
            # A single embedding vector becomes a one-row matrix.
            encoding_matrix = embeddings.reshape(1, -1)
        elif embeddings.ndim == 2:
            # Multiple frames - embeddings are already one row per frame.
            encoding_matrix = embeddings
        else:
            raise ValueError(
                f"embeddings must be 1D or 2D, got shape {embeddings.shape}"
            )

        if encoding_matrix.size == 0:
            raise ValueError(
                f"embeddings must be non-empty, got shape {embeddings.shape}"
            )

        n_frames, embedding_dim = encoding_matrix.shape

        # cv2.resize does not accept float16 or most integer dtypes, so work
        # in float64 unless the input is already float32/float64.
        if encoding_matrix.dtype in (np.float32, np.float64):
            values = encoding_matrix
        else:
            values = encoding_matrix.astype(np.float64)

        # NaN would make the min/max comparison false (silently producing an
        # all-zero image); inf would produce NaNs that cast to garbage uint8.
        if not np.all(np.isfinite(values)):
            raise ValueError("embeddings contain NaN or infinite values")

        # Normalize to [0, 1] range; a constant matrix maps to all zeros.
        matrix_min = values.min()
        matrix_max = values.max()

        if matrix_max > matrix_min:
            normalized_matrix = (values - matrix_min) / (matrix_max - matrix_min)
        else:
            normalized_matrix = np.zeros_like(values)

        # Resize to target dimensions; cv2 takes dsize as (width, height) and
        # may reject non-contiguous arrays (e.g. transposed views).
        resized_matrix = cv2.resize(
            np.ascontiguousarray(normalized_matrix),
            (self.target_width, self.target_height),
        )

        # Convert to uint8 for visualization; the clip guards against float
        # rounding landing just outside [0, 1].
        encoding_image = (np.clip(resized_matrix, 0.0, 1.0) * 255).astype(np.uint8)

        result = {
            "encoding_matrix": encoding_matrix,
            "normalized_matrix": normalized_matrix,
            "encoding_image": encoding_image,
            "original_embeddings": embeddings,
            "n_frames": n_frames,
            "embedding_dim": embedding_dim,
            "encoding_shape": encoding_image.shape,
        }

        logger.info("Created encoding matrix: %s", encoding_image.shape)
        return result
