import os
from pathlib import Path
from typing import Optional, Tuple

import cv2

from ..annotations.bounding_box import BoundingBox
from ..models.nassco.macp.inspection import MACPInspection
from ..models.nassco.pacp.inspection import PACPInspection
from ..models.types import InspectionUnion
from ..models.wrc.mainline.inspection import MainlineInspection
from ..models.wrc.manhole.inspection import ManholeInspection
from .defect_properties import DefectBand


def draw_bounding_box(
    image_path: str,
    bounding_box: BoundingBox,
    color: Tuple[int, int, int] = (0, 0, 255),
    thickness: int = 2,
    output_path: Optional[str] = None,
    quality: Optional[int] = None,
    target_size_kb: Optional[int] = None,
    short_code: Optional[str] = None,
):
    """
    Annotate an image with a bounding box (and optional short-code label) and optionally save it compressed.

    Args:
        image_path: Path to the source image file.
        bounding_box: Box to draw on the image.
        color: Color used for the box and the label.
        thickness: Line thickness used for the box and the label.
        output_path: Where to write the annotated image. When falsy, nothing is written and None is returned.
        quality: Compression quality (0-100). When None, it is derived from the source file size
            (and ``target_size_kb`` if given).
        target_size_kb: Desired output size in KB. When truthy, it also caps the saved file size
            by letting the saver lower quality step by step.
        short_code: Defect short code drawn as a label next to the box.

    Returns:
        The path returned by ``save_compressed_image`` when ``output_path`` is given, otherwise None.
    """
    annotated = _draw_bounding_box_on_image(image_path, bounding_box, color, thickness, short_code)

    # Without a destination, the annotation is computed but not persisted.
    if not output_path:
        return None

    if quality is None:
        quality = _calculate_optimal_quality(image_path, target_size_kb)

    # The saver expects megabytes; a falsy target means "no size cap".
    size_limit_mb = target_size_kb / 1024.0 if target_size_kb else None
    return save_compressed_image(
        image=annotated,
        output_path=output_path,
        quality=quality,
        max_size_mb=size_limit_mb,
    )


def _draw_bounding_box_on_image(image_path: str, bounding_box: BoundingBox, color: Tuple[int, int, int] = (0, 0, 255), thickness: int = 2, short_code: Optional[str] = None):
    """
    Draws a bounding box on an image and returns the modified image without saving.

    Args:
        image_path: Path to the image file.
        bounding_box: Bounding box to draw on the image (reads xmin, ymin, xmax, ymax).
        color: Color of the bounding box and label, in OpenCV's channel order.
        thickness: Thickness of the bounding box; also used as the label stroke thickness.
        short_code: Short code of the defect. If given, it is drawn as a text label next to the box.
    Returns:
        The modified image as a numpy array.
    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread does not raise for missing/unreadable files; it returns None,
        # which would otherwise surface later as an opaque cv2.error or AttributeError.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # OpenCV drawing functions only accept integer pixel coordinates.
    xmin, ymin, xmax, ymax = (
        int(round(value))
        for value in (bounding_box.xmin, bounding_box.ymin, bounding_box.xmax, bounding_box.ymax)
    )
    cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness)

    if short_code:
        height = image.shape[0]

        # Prefer placing the label just above the box; fall back to just below it.
        text_y_above = ymin - 5
        text_y_below = ymax + 15  # 15 pixels below the bottom border

        if text_y_above >= 15:  # Enough space above (15px margin)
            text_y = text_y_above
        elif text_y_below <= height - 15:  # Enough space below (15px margin)
            text_y = text_y_below
        else:
            # No room above or below: draw the label on the top border of the box.
            text_y = ymin

        cv2.putText(image, short_code, (xmin, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)
    return image


def _calculate_optimal_quality(original_image_path: str, target_size_kb: Optional[int] = None) -> int:
    """
    Calculate optimal JPEG quality based on original image size.

    Args:
        original_image_path: Path to the original image
        target_size_kb: Target file size in KB. If None, uses heuristics based on original size.

    Returns:
        Optimal JPEG quality (0-100)
    """
    original_size_kb = Path(original_image_path).stat().st_size / 1024.0

    if target_size_kb is not None:
        # Calculate quality based on target size
        size_ratio = target_size_kb / original_size_kb

        # Map size ratio to quality (empirical values)
        if size_ratio >= 1.0:
            return 95  # No compression needed
        elif size_ratio >= 0.8:
            return 90
        elif size_ratio >= 0.6:
            return 85
        elif size_ratio >= 0.4:
            return 75
        elif size_ratio >= 0.2:
            return 60
        elif size_ratio >= 0.1:
            return 45
        else:
            return 30  # Minimum quality
    else:
        # Heuristic based on original file size
        if original_size_kb < 100:  # Already small
            return 95
        elif original_size_kb < 500:  # Small
            return 90
        elif original_size_kb < 1000:  # Medium
            return 85
        elif original_size_kb < 2000:  # Large
            return 75
        elif original_size_kb < 5000:  # Very large
            return 60
        else:  # Extremely large
            return 45


def generate_inspection_pictures(
    inspection: InspectionUnion,
    frame_dir: str,
    output_dir: str,
    quality: Optional[int] = None,
    target_size_kb: Optional[int] = None,
    bounding_box: bool = True,
    display_short_codes: bool = True,
):
    """
    Generate annotated pictures for the inspection.

    Each observation's frame is read from ``frame_dir``, optionally annotated with its
    bounding box (colored by defect severity band), and saved compressed into ``output_dir``
    under a unique name ``<base>_<index:04d><ext>``. As a side effect, each observation's
    ``image_reference`` is rewritten to that unique output filename.

    Args:
        inspection: The inspection object containing observations.
        frame_dir: Directory containing the raw frame images.
        output_dir: Directory to save the annotated images (created if missing).
        quality: JPEG quality (0-100) for image compression. If None, auto-calculated per frame.
        target_size_kb: Target file size in KB. If provided, quality will be adjusted to meet this target.
        bounding_box: Whether to draw the bounding box on the image.
        display_short_codes: Whether to display the short codes on the image.

    Returns:
        List of paths to the saved images, in observation order.

    Raises:
        ValueError: If the inspection type is unsupported, an observation has no image
            reference, a raw frame cannot be decoded, or a frame could not be saved.
        FileNotFoundError: If a referenced raw frame file does not exist.
    """
    if isinstance(inspection, (PACPInspection, MACPInspection)):
        color_band_attr = "nassco_band_color_bgr"
    elif isinstance(inspection, (MainlineInspection, ManholeInspection)):
        color_band_attr = "wrc_band_color_bgr"
    else:
        raise ValueError(f"Unsupported inspection type: {type(inspection)}")

    # Size cap in MB for the saver; a falsy target means "no cap".
    max_size_mb = target_size_kb / 1024.0 if target_size_kb else None

    annotated_frames = []
    os.makedirs(output_dir, exist_ok=True)

    for idx, observation in enumerate(inspection.observations):
        if not observation.image_reference:
            raise ValueError(f"Image name not found for {observation.code}")

        frame_name = observation.image_reference
        raw_frame_path = os.path.join(frame_dir, frame_name)
        if not os.path.exists(raw_frame_path):
            raise FileNotFoundError(f"Raw frame file not found: {raw_frame_path}")

        # Index suffix keeps outputs unique when several observations share a raw frame.
        base_name, ext = os.path.splitext(frame_name)
        unique_output_name = f"{base_name}_{idx:04d}{ext}"
        output_path = os.path.join(output_dir, unique_output_name)

        if bounding_box:
            defect_band = DefectBand(severity=observation.severity or 0)
            color = getattr(defect_band, color_band_attr)
            annotated_frame_path = draw_bounding_box(
                raw_frame_path,
                observation.bounding_box,
                color=color,
                thickness=2,
                output_path=output_path,
                quality=quality,
                target_size_kb=target_size_kb,
                short_code=observation.code.abbreviation if display_short_codes and observation.code else None,
            )
        else:
            # Resolve quality per frame into a local; reassigning ``quality`` here would
            # make the first frame's auto-computed value stick for every later frame.
            frame_quality = quality if quality is not None else _calculate_optimal_quality(raw_frame_path, target_size_kb)
            raw_image = cv2.imread(raw_frame_path)
            if raw_image is None:
                # cv2.imread returns None (instead of raising) for undecodable files.
                raise ValueError(f"Could not decode raw frame image: {raw_frame_path}")
            annotated_frame_path = save_compressed_image(
                image=raw_image,
                output_path=output_path,
                quality=frame_quality,
                max_size_mb=max_size_mb,
            )
        if annotated_frame_path is None:
            raise ValueError(f"Annotated frame could not be saved: {output_path}")

        # Point the observation at the unique output file so downstream code
        # (e.g., PDF generation) references the correct image.
        observation.image_reference = unique_output_name

        annotated_frames.append(annotated_frame_path)
    return annotated_frames


def _imwrite_params(output_ext: str, quality: int) -> list:
    """Build ``cv2.imwrite`` parameters for a lowercase extension at a 0-100 quality; empty if quality does not apply."""
    if output_ext in (".jpg", ".jpeg"):
        return [cv2.IMWRITE_JPEG_QUALITY, max(0, min(100, int(quality)))]
    if output_ext == ".png":
        # Map 0-100 quality onto zlib level 0-9 (higher quality -> less compression).
        # Clamp because quality 0 would otherwise map to 10, outside OpenCV's valid range.
        png_compression = max(0, min(9, int((100 - quality) / 10)))
        return [cv2.IMWRITE_PNG_COMPRESSION, png_compression]
    return []


def save_compressed_image(image, output_path: str, quality: int = 95, max_size_mb: Optional[float] = None):
    """
    Save an image with compression, optionally limiting file size.

    Args:
        image: OpenCV image array (numpy array).
        output_path: Path to save the image; its extension selects the encoder.
        quality: Quality on a 0-100 scale. Used as the JPEG quality for ``.jpg``/``.jpeg``.
            For ``.png`` it is mapped to compression level ``(100 - quality) // 10``, clamped to 0-9.
            Ignored for other formats.
        max_size_mb: Maximum file size in MB. If exceeded, quality is lowered in steps of 10
            (never below 11) and the image re-saved. This is best-effort: the last attempt is
            kept even if it is still too large.

    Returns:
        Path to the saved image.

    Raises:
        OSError: If OpenCV fails to write the file (e.g. missing directory or unsupported extension).
    """
    output_ext = Path(output_path).suffix.lower()

    def _write(q: int) -> list:
        # Write once at quality q; cv2.imwrite reports failure via its return value, not an exception.
        params = _imwrite_params(output_ext, q)
        if not cv2.imwrite(output_path, image, params):
            raise OSError(f"Failed to write image: {output_path}")
        return params

    current_quality = quality
    # Always write at least once, so the returned path exists even for very low initial qualities.
    encode_params = _write(current_quality)

    # Re-encoding only helps when the format honors a quality parameter.
    if max_size_mb is not None and encode_params:
        while current_quality - 10 > 10:  # Don't go below 10% quality
            file_size_mb = Path(output_path).stat().st_size / (1024 * 1024)
            if file_size_mb <= max_size_mb:
                break
            current_quality -= 10
            _write(current_quality)

    return Path(output_path)
