import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import cv2
import imagehash
import numpy as np
from PIL import Image
from pydantic import BaseModel

from ..config.frame_extractor_config import FrameExtractorConfig


class VideoInfo(BaseModel):
    """Basic properties of a video as read from an OpenCV capture."""

    # Last "/"-separated component of the video path ("" when no path is known)
    filename: str
    # Value of cv2.CAP_PROP_FPS
    fps: float
    # Value of cv2.CAP_PROP_FRAME_COUNT, truncated to int
    total_frames: int
    # Values of cv2.CAP_PROP_FRAME_WIDTH / cv2.CAP_PROP_FRAME_HEIGHT, truncated to int
    width: int
    height: int


class ExtractionStats(BaseModel):
    """Counters and timing for one extraction or filtering run."""

    # Number of frames kept by the operation
    frames_extracted: int
    # Elapsed wall-clock time in seconds (difference of two time.time() readings)
    extraction_time: float
    # frames_extracted / extraction_time; producers in this module use 0 when the elapsed time is 0
    extraction_rate: float
    # Optional fields depending on the operation that produced the stats
    # Sampling step in frames; set by the frame extraction methods
    frame_interval: int | None = None
    # Configured target FPS; set by the frame extraction methods
    target_fps: float | None = None
    # Configured skip probability for similar frames; set by filter_frames_by_hash
    skip_similar: float | None = None


@dataclass
class Frame:
    """A single extracted video frame."""

    # 0-based position of the frame in the sequence read from the capture
    key: int
    # Frame content as a PIL image (the extractor converts OpenCV's BGR data to RGB)
    image: Image.Image


@dataclass
class Frames:
    """A collection of frames together with their image dimensions."""

    # The frames themselves
    items: List[Frame]
    # Size of the first extracted image as filled in by FrameExtractor;
    # both are 0 when no frame was extracted
    image_width: int
    image_height: int


@dataclass
class ExtractFramesResult:
    """Outcome of a frame extraction run."""

    # True when no frame was extracted
    is_empty: bool
    # The extracted frames
    frames: Frames
    # Timing and counters for the run
    extraction_stats: ExtractionStats
    # Properties of the source video
    video_info: VideoInfo


@dataclass
class FilterFramesResult:
    """Outcome of filtering an already extracted set of frames."""

    # The frames that survived filtering
    frames: Frames
    # Timing and counters for the filtering run
    extraction_stats: ExtractionStats


# Module-level logger, named after this module
logger = logging.getLogger(__name__)


class FrameExtractor:
    """
    Extract frames from videos at specified FPS intervals.

    This class focuses solely on frame extraction logic.
    """

    def __init__(self, config: Optional[FrameExtractorConfig] = None, video_path: Optional[str] = None):
        """
        Initialize frame extractor.

        Args:
            config: The frame extractor config. When omitted (or None), a new
                FrameExtractorConfig() is created for this instance.
            video_path: Optional path of the video to bind. It is used by the
                context manager and as the default source for methods called
                without an explicit video_path.
        """

        # Build the default here rather than in the signature: a call in a
        # default argument runs once at definition time and its result would
        # be shared by every extractor created without a config.
        self.config = config if config is not None else FrameExtractorConfig()
        self._video_path: Optional[str] = video_path
        # Capture owned by the context manager; None outside a `with` block
        self._cap: Optional[cv2.VideoCapture] = None
        self._cap: Optional[cv2.VideoCapture] = None

    def __enter__(self):
        """
        Open the bound video and keep the capture for the duration of the context.

        Returns:
            This extractor, with an opened capture bound to it

        Raises:
            ValueError: If no video path has been set or the video cannot be opened
        """
        if not self._video_path:
            raise ValueError("video_path must be set before entering context. Use FrameExtractor(config, video_path) or .open(path).")

        capture = cv2.VideoCapture(self._video_path)
        if not capture.isOpened():
            # __exit__ is not called when __enter__ raises, so clean up here and
            # never store the unusable capture: other methods treat a non-None
            # self._cap as a valid bound capture.
            capture.release()
            raise ValueError(f"Cannot open video: {self._video_path}")

        self._cap = capture
        return self

    def __exit__(self, exc_type, exc, tb):
        """
        Release the context-bound capture, if any, and forget it.

        Returns:
            False, so an exception raised inside the `with` block is not suppressed
        """
        # Detach first so the attribute is cleared even when release() raises
        capture, self._cap = self._cap, None
        if capture is not None:
            capture.release()
        return False

    def open(self, video_path: str) -> "FrameExtractor":
        """
        Set the video path to use and return this extractor for chaining.

        Only the path is stored; no capture is opened or changed here. The
        capture is opened by the context manager (`__enter__`) or by the
        individual methods when they are called outside a context.

        Args:
            video_path: Path to video file

        Returns:
            This extractor
        """
        self._video_path = video_path
        return self

    def extract_frames(self, video_path: Optional[str] = None, start_min: int = 0, end_min: Optional[int] = None) -> ExtractFramesResult:
        """
        Extract frames from video at target FPS.

        Frames are read sequentially; every frame whose index is a multiple of
        the sampling interval (video FPS / target FPS, rounded, at least 1) and
        that lies inside [start, end) is converted to an RGB PIL image and kept
        in memory.

        Args:
            video_path: Path to video file. It is opened as a temporary capture
                when no context-managed capture is active. When one is active,
                frames are read from that capture and this value is only used
                for logging and for the reported filename.
            start_min: Start time in minutes (0 = beginning of video)
            end_min: End time in minutes (None = end of video, 0 = beginning of video)

        Returns:
            ExtractFramesResult containing frames, stats, and video info

        Raises:
            ValueError: If no video source is available or the video cannot be opened
        """
        # Resolve active capture/path
        bound_cap = self._cap
        active_path: Optional[str] = video_path if video_path is not None else self._video_path
        if bound_cap is None:
            if not active_path:
                raise ValueError("No video source. Provide video_path or use the context manager with .open(path).")
            logger.info(f"Extracting frames from: {active_path}, start_min = {start_min}, end_min = {end_min}")
        else:
            logger.info(f"Extracting frames from bound capture: {active_path}, start_min = {start_min}, end_min = {end_min}")
        start_time = time.time()

        # Open a temporary capture only when the context manager did not provide one
        owns_cap = bound_cap is None
        if owns_cap:
            cap = cv2.VideoCapture(active_path)
            if not cap.isOpened():
                cap.release()
                raise ValueError(f"Cannot open video: {active_path}")
        else:
            cap = bound_cap

        frames_list: list[Frame] = []
        # Size of the first extracted image; stays 0 when nothing is extracted
        image_width = 0
        image_height = 0

        # Everything that touches the capture runs under try/finally so a
        # temporary capture is released even if reading or conversion fails.
        try:
            # Get video properties
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Calculate sampling interval
            frame_interval = max(1, int(round(video_fps / self.config.target_fps)))

            logger.info(f"Video: {total_frames} frames at {video_fps:.2f} FPS")
            logger.info(f"Extracting every {frame_interval} frames")

            start_frame = int(start_min * 60 * video_fps)
            end_frame = int(end_min * 60 * video_fps) if end_min is not None else total_frames
            logger.info(f"start = {start_frame}, end = {end_frame}")

            # Index of the frame just read; it doubles as the frame key.
            # NOTE: counting always starts at 0 for the first read of this call,
            # so with a bound capture that was already read from, indices are
            # relative to its current position.
            frame_index = -1
            while True:
                ret, frame = cap.read()
                frame_index += 1

                # We skip frames before the start_frame, and break if we get to the end
                if not ret:
                    break
                if frame_index < start_frame:
                    continue
                # Without end_min we read until the capture runs out of frames
                if end_min is not None and frame_index >= end_frame:
                    break

                # Extract frame at interval
                if frame_index % frame_interval == 0:
                    # Convert BGR to RGB and create PIL Image
                    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    pil_image = Image.fromarray(rgb_frame)

                    if not frames_list:
                        image_width = pil_image.width
                        image_height = pil_image.height

                    frames_list.append(Frame(key=frame_index, image=pil_image))
        finally:
            if owns_cap:
                cap.release()

        extraction_time = time.time() - start_time
        extracted_count = len(frames_list)

        # Prepare results
        video_info = VideoInfo(
            filename=(active_path or "").split("/")[-1],
            fps=float(video_fps),
            total_frames=total_frames,
            width=width,
            height=height,
        )

        extraction_stats = ExtractionStats(
            frames_extracted=extracted_count,
            frame_interval=frame_interval,
            target_fps=self.config.target_fps,
            extraction_time=extraction_time,
            extraction_rate=extracted_count / extraction_time if extraction_time > 0 else 0,
        )

        is_empty = extracted_count == 0

        logger.info(f"Extracted {extracted_count} frames in {extraction_time:.2f}s")

        return ExtractFramesResult(
            is_empty=is_empty,
            frames=Frames(items=frames_list, image_width=image_width, image_height=image_height),
            extraction_stats=extraction_stats,
            video_info=video_info,
        )

    def extract_frames_generator(self, video_path: Optional[str] = None, start_min: int = 0, end_min: Optional[int] = None):
        """
        Lazily yield sampled frames from a video, one at a time.

        Because this is a generator, nothing below runs (including validation
        and opening the video) until the first item is requested.

        Args:
            video_path: Path to video file; overrides the bound path when given
            start_min: Start time in minutes (0 = beginning of video)
            end_min: End time in minutes (None = end of video, 0 = beginning of video)

        Yields:
            Frame objects whose key is the 0-based index of the frame as read

        Raises:
            ValueError: If no video source is available or the video cannot be opened
        """
        bound_capture = self._cap
        active_path: Optional[str] = self._video_path if video_path is None else video_path

        # Use the context-managed capture when there is one, otherwise open our own
        owns_capture = bound_capture is None
        if owns_capture:
            if not active_path:
                raise ValueError("No video source. Provide video_path or use the context manager with .open(path).")
            logger.info(f"Extracting frames from: {active_path}, start_min = {start_min}, end_min = {end_min}")
            cap = cv2.VideoCapture(active_path)
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {active_path}")
        else:
            logger.info(f"Extracting frames from bound capture: {active_path}, start_min = {start_min}, end_min = {end_min}")
            cap = bound_capture

        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Sampling step in frames, never below 1
            frame_interval = max(1, int(round(video_fps / self.config.target_fps)))

            logger.info(f"Video: {total_frames} frames at {video_fps:.2f} FPS")
            logger.info(f"Extracting every {frame_interval} frames")

            start_frame = int(start_min * 60 * video_fps)
            if end_min is None:
                end_frame = total_frames
            else:
                end_frame = int(end_min * 60 * video_fps)
            logger.info(f"start = {start_frame}, end = {end_frame}")

            # Position of the frame being read; also used as the frame key
            position = 0
            while True:
                ok, bgr_frame = cap.read()
                if not ok:
                    break

                if position >= start_frame:
                    # The end bound only applies when end_min was given
                    if end_min is not None and position >= end_frame:
                        break

                    if position % frame_interval == 0:
                        # OpenCV delivers BGR; PIL expects RGB
                        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
                        yield Frame(key=position, image=Image.fromarray(rgb_frame))

                position += 1

        finally:
            # Only a capture opened here is ours to release
            if owns_capture:
                cap.release()

    def get_video_info(self, video_path: Optional[str] = None) -> VideoInfo:
        """
        Get video information without extracting frames.

        Args:
            video_path: Path to video file. It is opened as a temporary capture
                when no context-managed capture is active. When one is active,
                the properties are read from that capture and this value is
                only used for the reported filename.

        Returns:
            VideoInfo object with video properties

        Raises:
            ValueError: If no video source is available or the video cannot be opened
        """
        # Resolve active capture/path
        bound_cap = self._cap
        active_path: Optional[str] = video_path if video_path is not None else self._video_path

        owns_cap = bound_cap is None
        if owns_cap:
            if not active_path:
                raise ValueError("No video source. Provide video_path or use the context manager with .open(path).")
            cap = cv2.VideoCapture(active_path)
            if not cap.isOpened():
                cap.release()
                raise ValueError(f"Cannot open video: {active_path}")
        else:
            cap = bound_cap

        # Read the properties once for both cases; the finally guarantees a
        # temporary capture is released even if a property read fails.
        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            if owns_cap:
                cap.release()

        return VideoInfo(
            filename=(active_path or "").split("/")[-1],
            fps=float(video_fps),
            total_frames=total_frames,
            width=width,
            height=height,
        )

    def extract_frames_with_stats(self, video_path: Optional[str] = None, start_min: int = 0, end_min: Optional[int] = None) -> ExtractFramesResult:
        """
        Collect all sampled frames via the generator and report stats alongside them.

        Args:
            video_path: Path to video file
            start_min: Start time in minutes (0 = beginning of video)
            end_min: End time in minutes (None = end of video, 0 = beginning of video)

        Returns:
            ExtractFramesResult containing frames, stats, and video info

        Raises:
            ValueError: If no video source is available or the video cannot be opened
        """
        began_at = time.time()

        # Video properties come from a separate lookup before any frame is read
        video_info = self.get_video_info(video_path)
        sampling_step = max(1, int(round(video_info.fps / self.config.target_fps)))

        # Drain the generator into a list
        collected = list(self.extract_frames_generator(video_path, start_min, end_min))
        extracted_count = len(collected)

        # Dimensions are taken from the first frame; 0 x 0 when there is none
        if collected:
            first_image = collected[0].image
            image_width = first_image.width
            image_height = first_image.height
        else:
            image_width = 0
            image_height = 0

        extraction_time = time.time() - began_at

        stats = ExtractionStats(
            frames_extracted=extracted_count,
            frame_interval=sampling_step,
            target_fps=self.config.target_fps,
            extraction_time=extraction_time,
            extraction_rate=extracted_count / extraction_time if extraction_time > 0 else 0,
        )

        logger.info(f"Extracted {extracted_count} frames in {extraction_time:.2f}s")

        return ExtractFramesResult(
            is_empty=not collected,
            frames=Frames(items=collected, image_width=image_width, image_height=image_height),
            extraction_stats=stats,
            video_info=video_info,
        )

    def filter_frames_by_hash(self, frames: Frames) -> FilterFramesResult:
        """
        Filter extracted frames using perceptual hashing to skip duplicates.

        Frames are processed in key order. Each frame's average hash is
        compared with the hash of the frame directly before it (whether or not
        that frame was kept). When the two hashes do not differ, the frame is
        dropped with probability ``config.skip_similar``; the first frame is
        always kept. The result is random unless NumPy's global RNG is seeded.

        Args:
            frames: Frames collection to filter

        Returns:
            FilterFramesResult containing the kept frames and stats; image
            dimensions are copied from the input
        """
        logger.info(f"Filtering images, removing {self.config.skip_similar * 100}% of duplicates")
        start_time = time.time()

        # Ensure processing in key order
        ordered_items = sorted(frames.items, key=lambda f: f.key)
        total_frames = len(ordered_items)

        filtered_items: list[Frame] = []
        last_hash = None

        for frame_item in ordered_items:
            current_hash = imagehash.average_hash(frame_item.image)
            previous_hash, last_hash = last_hash, current_hash

            if previous_hash is not None:
                # Hash subtraction gives the number of differing bits, so < 1
                # means the hashes are identical.
                diff = current_hash - previous_hash

                # rand() draws from [0, 1); a strict comparison makes
                # skip_similar == 0 never drop a frame and skip_similar == 1
                # always drop a duplicate. The draw only happens for duplicates.
                if diff < 1 and np.random.rand() < self.config.skip_similar:
                    continue

            filtered_items.append(frame_item)

        extracted_count = len(filtered_items)
        extraction_time = time.time() - start_time

        extraction_stats = ExtractionStats(
            frames_extracted=extracted_count,
            skip_similar=self.config.skip_similar,
            extraction_time=extraction_time,
            extraction_rate=extracted_count / extraction_time if extraction_time > 0 else 0,
        )

        logger.info(f"removed {total_frames - extracted_count} out of {total_frames} frames in {extraction_time:.2f}s")

        return FilterFramesResult(
            frames=Frames(items=filtered_items, image_width=frames.image_width, image_height=frames.image_height),
            extraction_stats=extraction_stats,
        )

    def save_frames(
        self,
        frames: Frames,
        output_dir: str,
        filename_pattern: str = "frame_{key:06d}.jpg",
        quality: int = 95,
    ) -> list[Path]:
        """
        Save frames to a local directory, in ascending key order.

        The writing itself is delegated to save_frames_from_generator so both
        entry points share one implementation of the format handling.

        Args:
            frames: Frames collection to save
            output_dir: Directory to write images to (created if missing)
            filename_pattern: Python format string with {key} placeholder
            quality: JPEG/WebP quality (if applicable)

        Returns:
            List of written file paths
        """
        ordered_items = sorted(frames.items, key=lambda f: f.key)
        return self.save_frames_from_generator(
            ordered_items,
            output_dir,
            filename_pattern=filename_pattern,
            quality=quality,
        )

    def save_frames_from_generator(
        self,
        frame_generator,
        output_dir: str,
        filename_pattern: str = "frame_{key:06d}.jpg",
        quality: int = 95,
    ) -> list[Path]:
        """
        Write each frame produced by an iterable to disk as it arrives.

        The file format is chosen from the extension of the generated filename.

        Args:
            frame_generator: Generator that yields Frame objects
            output_dir: Directory to write images to (created if missing)
            filename_pattern: Python format string with {key} placeholder
            quality: JPEG/WebP quality (if applicable)

        Returns:
            List of written file paths
        """
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Keyword arguments for Image.save, per known file extension
        jpeg_options = {"format": "JPEG", "quality": quality, "optimize": True}
        save_options = {
            ".jpg": jpeg_options,
            ".jpeg": jpeg_options,
            ".png": {"format": "PNG", "optimize": True},
            ".webp": {"format": "WEBP", "quality": quality, "method": 6},
        }

        saved_paths: list[Path] = []
        for frame in frame_generator:
            target = Path(output_dir) / filename_pattern.format(key=frame.key)
            suffix = target.suffix.lower()

            picture = frame.image
            if suffix in {".jpg", ".jpeg"}:
                # JPEG output is always written from an RGB image
                picture = picture.convert("RGB")

            # Unknown extensions get no options: PIL picks the format from the
            # filename (may raise if unknown)
            picture.save(target, **save_options.get(suffix, {}))
            saved_paths.append(target)

        return saved_paths
