import os
import logging
import tempfile
import json

from .annotation_utils import get_labels_from_annotation_file
from .utils import (
    folder_exists_in_s3,
    upload_file_to_s3,
    file_exists_in_s3,
    list_files_in_s3,
    zip_directory,
    download_file_from_s3,
)
from .s3_config import (
    FRAMES_PATH,
    FRAMES_FILE,
    ANNOTATIONS_PATH,
    ANNOTATIONS_FILE,
    MANIFESTS_PATH,
    MANIFEST_FILE,
)
from .pydantic_models import DatasetManifest, DatasetVersionInfo

# Root logging is configured as a side effect of importing this module.
# logging.basicConfig() does nothing when the root logger already has handlers.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logger = logging.getLogger(__name__)
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)


class VersionedDatasetManager:
    """
    Manages versioned datasets stored in S3, including frames, annotations, and manifests.
    """

    def __init__(self, bucket: str, s3_client, prefix: str = ""):
        """
        Initialize the VersionedDatasetManager.

        Args:
            bucket: Name of the S3 bucket where datasets are stored.
            s3_client: Boto3 S3 client instance for S3 operations.
            prefix: Optional S3 prefix/path to prepend to all dataset paths.
                   Defaults to empty string (root of bucket).
        """
        self.s3_client = s3_client
        self.bucket = bucket
        normalized_prefix = (prefix or "").strip("/")
        self.prefix = normalized_prefix

    # ======================================================
    # Directory path builders (no S3 checks here)
    # ======================================================
    def get_manifests_dir(self, dataset_id: str) -> str:
        """
        Build the S3 key of the directory that holds a dataset's manifest.

        No S3 request is made; the key is assembled purely from strings.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Key of the form ``{prefix}/{MANIFESTS_PATH}/dataset_id={dataset_id}``.
            The ``{prefix}/`` part is omitted when no prefix is configured.
        """
        root = f"{self.prefix}/" if self.prefix else ""
        raw_key = f"{root}{MANIFESTS_PATH}/dataset_id={dataset_id}"
        # Drop surrounding slashes and collapse doubled separators.
        return raw_key.strip("/").replace("//", "/")

    def get_annotations_dir(self, dataset_id: str) -> str:
        """
        Build the S3 key of the directory that holds a dataset's annotations.

        No S3 request is made; the key is assembled purely from strings.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Key of the form ``{prefix}/{ANNOTATIONS_PATH}/dataset_id={dataset_id}``.
            The ``{prefix}/`` part is omitted when no prefix is configured.
        """
        root = f"{self.prefix}/" if self.prefix else ""
        raw_key = f"{root}{ANNOTATIONS_PATH}/dataset_id={dataset_id}"
        # Drop surrounding slashes and collapse doubled separators.
        return raw_key.strip("/").replace("//", "/")

    def get_frames_dir(self, dataset_id: str) -> str:
        """
        Build the S3 key of the directory that holds a dataset's frames archive.

        No S3 request is made; the key is assembled purely from strings.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            Key of the form ``{prefix}/{FRAMES_PATH}/dataset_id={dataset_id}``.
            The ``{prefix}/`` part is omitted when no prefix is configured.
        """
        root = f"{self.prefix}/" if self.prefix else ""
        raw_key = f"{root}{FRAMES_PATH}/dataset_id={dataset_id}"
        # Drop surrounding slashes and collapse doubled separators.
        return raw_key.strip("/").replace("//", "/")

    # ======================================================
    # File path builders (optional S3 existence checking)
    # ======================================================
    def get_manifest_path(self, dataset_id: str, allow_new: bool = False) -> str:
        """
        Build the S3 key of a dataset's manifest file.

        Args:
            dataset_id: Unique identifier for the dataset.
            allow_new: When True the key is returned without touching S3.
                      When False the object must already exist.

        Returns:
            S3 key of the manifest file.

        Raises:
            ValueError: If allow_new is False and the manifest is not in S3.
        """
        manifest_key = f"{self.get_manifests_dir(dataset_id)}/{MANIFEST_FILE}"
        manifest_key = manifest_key.strip("/").replace("//", "/")

        # Callers that are about to create the object skip the existence probe.
        if allow_new:
            return manifest_key

        if file_exists_in_s3(self.bucket, manifest_key, self.s3_client):
            return manifest_key

        raise ValueError(
            f"manifest file not present on s3://{self.bucket}/{manifest_key}"
        )

    def get_frames_path(self, dataset_id: str, allow_new: bool = False) -> str:
        """
        Build the S3 key of a dataset's frames archive.

        Args:
            dataset_id: Unique identifier for the dataset.
            allow_new: When True the key is returned without touching S3.
                      When False the object must already exist.

        Returns:
            S3 key of the frames archive.

        Raises:
            ValueError: If allow_new is False and the archive is not in S3.
        """
        frames_key = f"{self.get_frames_dir(dataset_id)}/{FRAMES_FILE}"
        frames_key = frames_key.strip("/").replace("//", "/")

        # Callers that are about to create the object skip the existence probe.
        if allow_new:
            return frames_key

        if file_exists_in_s3(self.bucket, frames_key, self.s3_client):
            return frames_key

        raise ValueError(f"frames file not present on s3://{self.bucket}/{frames_key}")

    def get_annotations_path(
        self, dataset_id: str, version: str = "", allow_new: bool = False
    ) -> str:
        """
        Get the S3 key path for a dataset's annotation file for a specific version.

        Args:
            dataset_id: Unique identifier for the dataset.
            version: Version string. If empty and allow_new is True, uses latest version.
                    If empty and allow_new is False, raises ValueError.
            allow_new: If True, return path even if file doesn't exist in S3.
                      If False, raise ValueError if file doesn't exist.

        Returns:
            S3 key path to the annotations file for the specified version.

        Raises:
            ValueError: If allow_new is False and version is empty, or if version
                       doesn't exist, or if annotations file doesn't exist in S3.
        """
        if not version:
            if not allow_new:
                raise ValueError("Use allow_new for finding latest manifest")
            version = self.get_latest_version(dataset_id)
        elif not allow_new and not self.version_exists(dataset_id, version):
            raise ValueError(f"Version {version} not found for dataset {dataset_id}")

        key = (
            f"{self.get_annotations_dir(dataset_id)}/version={version}/{ANNOTATIONS_FILE}"
        )
        key = key.strip("/").replace("//", "/")

        if not allow_new and not file_exists_in_s3(self.bucket, key, self.s3_client):
            raise ValueError(
                f"Annotations file not present on s3://{self.bucket}/{key}"
            )

        return key

    # ======================================================
    # Manifest Helpers
    # ======================================================
    def is_new_dataset(self, dataset_id: str) -> bool:
        """
        Check if a dataset is new (its manifests directory is absent from S3).

        The probed key comes from get_manifests_dir(), so it is the same
        location that manifests are written to. That builder omits the prefix
        segment when no prefix is configured, which keeps the key free of a
        leading "/" for managers rooted at the top of the bucket.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            True if the dataset's manifests directory is not found in S3,
            False if it is found.

        Raises:
            Exception: Any error raised by the S3 existence check is logged
                and re-raised unchanged.
        """
        # NOTE(review): whether "dataset_id=abc" also matches
        # "dataset_id=abcd" depends on how folder_exists_in_s3 treats the key
        # it is given (not visible from here) — confirm against the helper.
        manifests_dir = self.get_manifests_dir(dataset_id)
        try:
            return not folder_exists_in_s3(self.bucket, manifests_dir, self.s3_client)
        except Exception as e:
            logger.error(f"Error checking dataset existence: {e}")
            raise

    def _load_manifest(self, dataset_id: str) -> DatasetManifest:
        """
        Load and parse the manifest file for a dataset from S3.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            DatasetManifest object parsed from the S3 manifest file.

        Raises:
            ValueError: If the dataset is new (no manifest exists).
            Exception: If there's an error downloading or parsing the manifest.
        """
        if self.is_new_dataset(dataset_id):
            raise ValueError("Cannot load manifest for new dataset")

        manifest_key = self.get_manifest_path(dataset_id)

        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
            tmp_path = tmp.name

        try:
            download_file_from_s3(self.bucket, manifest_key, tmp_path, self.s3_client)
            with open(tmp_path, "r") as f:
                data = json.load(f)
            return DatasetManifest(**data)
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    def _save_manifest(self, dataset_id: str, manifest: DatasetManifest):
        """
        Serialize a manifest to JSON and upload it to the dataset's manifest key.

        Args:
            dataset_id: Unique identifier for the dataset.
            manifest: DatasetManifest object to store.

        Raises:
            Exception: If writing the temporary file or uploading to S3 fails.
        """
        s3_key = self.get_manifest_path(dataset_id, allow_new=True)

        # Reserve a local file name; the handle is closed straight away and the
        # path is reopened below for the actual write.
        scratch = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json")
        scratch.close()
        local_copy = scratch.name

        try:
            with open(local_copy, "w") as handle:
                json.dump(manifest.model_dump(), handle, indent=2)
            upload_file_to_s3(local_copy, self.bucket, s3_key, self.s3_client)
        finally:
            # delete=False above means the scratch file must be removed by hand.
            if os.path.exists(local_copy):
                os.unlink(local_copy)

    # ======================================================
    # Versioning Logic
    # ======================================================
    def version_exists(self, dataset_id: str, version: str) -> bool:
        """
        Check if a specific version exists for a dataset.

        Args:
            dataset_id: Unique identifier for the dataset.
            version: Version string to check.

        Returns:
            True if the version exists in the dataset's manifest, False otherwise.
            Returns False if the dataset is new or if there's an error loading the manifest.
        """
        if self.is_new_dataset(dataset_id):
            return False

        try:
            manifest = self._load_manifest(dataset_id)
            return version in manifest.versions
        except Exception as e:
            logger.error(f"Error checking version existence: {e}")
            return False

    def get_latest_version(self, dataset_id: str) -> str:
        """
        Get the latest version string for a dataset by comparing numeric version values.

        Args:
            dataset_id: Unique identifier for the dataset.

        Returns:
            String representation of the highest numeric version found in the manifest.
            Versions are parsed as floats for comparison.

        Raises:
            ValueError: If the dataset is new, has no versions, or has no valid numeric versions.
        """
        if self.is_new_dataset(dataset_id):
            raise ValueError("Cannot get latest version for new dataset")

        manifest = self._load_manifest(dataset_id)

        if not manifest.versions:
            raise ValueError(f"No versions found for dataset {dataset_id}")

        parsed = []
        for v in manifest.versions:
            try:
                parsed.append(float(v))
            except ValueError:
                logger.warning(f"Invalid version string in manifest: {v}")

        if not parsed:
            raise ValueError(f"No valid versions for dataset {dataset_id}")

        return str(max(parsed))

    # ======================================================
    # Upload Logic
    # ======================================================
    def _upload_dataset(
        self,
        dataset_id: str,
        frames_dir: str,
        annotation_file: str,
        version: str,
        annotation_type: str = "COCO",
        dataset_name: str = ""
    ):
        """
        Internal method to upload a dataset version to S3.

        Uploads frames (as zip), annotations file, and updates/creates the manifest.
        This method handles both new datasets and version updates.

        Args:
            dataset_id: Unique identifier for the dataset.
            frames_dir: Local directory path containing frame files to upload.
            annotation_file: Local file path to the annotation JSON file.
            version: Version string for this dataset version.
            annotation_type: Type of annotations (e.g., "COCO"). Defaults to "COCO".
            dataset_name: Name of the dataset. Required for new datasets.

        Raises:
            ValueError: If creating a new dataset but dataset_name is not provided.
            Exception: If there's an error during file operations or S3 uploads.
        """
    
        with tempfile.TemporaryDirectory() as temp_dir:
            new_dataset = self.is_new_dataset(dataset_id)

            num_frames = len(os.listdir(frames_dir))

            # frames.zip
            frames_zip_path = os.path.join(temp_dir, "frames.zip")
            zip_directory(frames_dir, frames_zip_path)
            frames_s3_key = self.get_frames_path(dataset_id, allow_new=True)
            upload_file_to_s3(frames_zip_path, self.bucket, frames_s3_key, self.s3_client)

            # annotations.json
            annotations_s3_key = self.get_annotations_path(
                dataset_id, version=version, allow_new=True
            )
            upload_file_to_s3(annotation_file, self.bucket, annotations_s3_key, self.s3_client)

            # manifest
            if new_dataset:
                if not dataset_name:
                    raise ValueError("New datasets require dataset name")
                manifest = DatasetManifest(dataset_id=dataset_id, dataset_name=dataset_name, frames_s3_paths=frames_s3_key, versions={})
            else:
                manifest = self._load_manifest(dataset_id)

            manifest.frames_s3_paths = frames_s3_key  # ensure updated path

            labels = get_labels_from_annotation_file(annotation_file)
            version_info = DatasetVersionInfo(
                version=version,
                annotation_type=annotation_type,
                annotation_file=annotations_s3_key,
                num_frames=num_frames,
                num_clips=0,
                num_videos=1,
                labels = labels
            )
            manifest.versions[version] = version_info
            self._save_manifest(dataset_id, manifest)

    def create_dataset(self, dataset_id, dataset_name, frames_dir, annotation_file, annotation_type="COCO"):
        """
        Create a new dataset with version 1.0.

        Args:
            dataset_id: Unique identifier for the new dataset.
            dataset_name: Name of the dataset.
            frames_dir: Local directory path containing frame files to upload.
            annotation_file: Local file path to the annotation JSON file.
            annotation_type: Type of annotations (e.g., "COCO"). Defaults to "COCO".

        Raises:
            ValueError: If a dataset with the given dataset_id already exists.
            Exception: If there's an error during the upload process.
        """
        if not self.is_new_dataset(dataset_id):
            raise ValueError(f"Dataset already exists for id {dataset_id}")

        self._upload_dataset(
            dataset_id, frames_dir, annotation_file, version="1.0", annotation_type=annotation_type, dataset_name=dataset_name
        )

    def update_dataset(self, dataset_id, frames_dir, annotation_file, annotation_type="COCO"):
        """
        Update an existing dataset by creating a new version.

        Automatically increments the version number by 1.0 from the latest version.

        Args:
            dataset_id: Unique identifier for the existing dataset.
            frames_dir: Local directory path containing frame files to upload.
            annotation_file: Local file path to the annotation JSON file.
            annotation_type: Type of annotations (e.g., "COCO"). Defaults to "COCO".

        Raises:
            ValueError: If the dataset doesn't exist (use create_dataset() instead).
            Exception: If there's an error during the upload process.
        """
        if self.is_new_dataset(dataset_id):
            raise ValueError(
                f"No existing dataset {dataset_id}. Use create_dataset() instead."
            )

        latest = self.get_latest_version(dataset_id)
        new_version = f"{float(latest) + 1:.1f}"

        logger.info(f"Updating dataset {dataset_id} from v{latest} → v{new_version}")

        self._upload_dataset(
            dataset_id,
            frames_dir,
            annotation_file,
            version=new_version,
            annotation_type=annotation_type,
        )

    def create_or_update_dataset(self, dataset_id, frames_dir, annotation_file, annotation_type="COCO", dataset_name=""):
        """
        Upload a dataset, creating it or adding a version as appropriate.

        A dataset that is not yet in S3 is created through create_dataset();
        an existing one gets a new version through update_dataset().

        Args:
            dataset_id: Unique identifier for the dataset.
            frames_dir: Local directory path containing frame files to upload.
            annotation_file: Local file path to the annotation JSON file.
            annotation_type: Type of annotations (e.g., "COCO"). Defaults to "COCO".
            dataset_name: Name of the dataset. Required for new datasets.

        Raises:
            ValueError: If creating a new dataset but dataset_name is not provided.
            Exception: If there's an error during the upload process.
        """
        if not self.is_new_dataset(dataset_id):
            logger.info(f"Updating dataset {dataset_id}")
            self.update_dataset(dataset_id, frames_dir, annotation_file, annotation_type)
            return

        logger.info(f"Creating new dataset {dataset_id}")
        self.create_dataset(
            dataset_id, dataset_name, frames_dir, annotation_file, annotation_type
        )

    # ======================================================
    # Download + Listing
    # ======================================================
    def download_dataset(self, dataset_id, output_dir, version: str = ""):
        """
        Fetch one version of a dataset from S3 into a local directory.

        Three objects are downloaded, in this order: the manifest, the
        annotation file of the requested version, and the frames archive.
        The output directory (and its "data" subdirectory) is created when
        missing.

        Args:
            dataset_id: Unique identifier for the dataset.
            output_dir: Local directory that receives the files. The manifest
                       is stored as "dataset_manifest.json", the annotations
                       under the ANNOTATIONS_FILE name, and the frames archive
                       under "data/" with the FRAMES_FILE name.
            version: Version string to download. If empty, the latest version
                    recorded in the manifest is used.

        Raises:
            ValueError: If the dataset, the requested version, or one of the
                       expected objects does not exist.
            Exception: If there's an error during the download process.
        """
        if not version:
            version = self.get_latest_version(dataset_id)

        os.makedirs(output_dir, exist_ok=True)

        def fetch(s3_key, local_path):
            # Every download uses the manager's bucket and client.
            download_file_from_s3(self.bucket, s3_key, local_path, self.s3_client)

        # manifest
        fetch(
            self.get_manifest_path(dataset_id),
            os.path.join(output_dir, "dataset_manifest.json"),
        )

        # annotations
        fetch(
            self.get_annotations_path(dataset_id, version),
            os.path.join(output_dir, ANNOTATIONS_FILE),
        )

        # frames archive, stored in its own "data" subdirectory
        frames_key = self.get_frames_path(dataset_id)
        data_dir = os.path.join(output_dir, "data")
        frames_local = os.path.join(data_dir, FRAMES_FILE)
        os.makedirs(data_dir, exist_ok=True)
        fetch(frames_key, frames_local)

    def get_datasets(self) -> dict[str, DatasetManifest]:
        """
        Collect the manifests of every dataset found under the manifests path.

        The manifests location is listed once. For each listed key that ends
        in the manifest file name and sits in a "dataset_id=<id>" directory,
        the manifest is loaded through _load_manifest().

        Returns:
            Dictionary mapping dataset_id to its DatasetManifest, containing
            only the datasets whose manifest loaded successfully.

        Note:
            Manifests are loaded one after another. A manifest that fails to
            load is logged as a warning and left out; no exception is raised
            for it.
        """
        listing_root = (
            f"{self.prefix}/{MANIFESTS_PATH}" if self.prefix else MANIFESTS_PATH
        )
        listing_root = listing_root.strip("/").replace("//", "/")

        listing = list_files_in_s3(
            bucket_name=self.bucket,
            prefix=listing_root,
            delimiter="",
            client=self.s3_client,
        )

        id_marker = "dataset_id="
        manifest_suffix = f"/{MANIFEST_FILE}"

        loaded = {}
        for key in listing.get("files", []):
            if not key.endswith(manifest_suffix):
                continue

            # The directory directly above the manifest file carries the id.
            parent_dir = key.rstrip("/").split("/")[-2]
            if not parent_dir.startswith(id_marker):
                continue

            dataset_id = parent_dir[len(id_marker):]
            try:
                loaded[dataset_id] = self._load_manifest(dataset_id)
            except Exception as e:
                logger.warning(
                    f"Failed loading manifest for dataset {dataset_id}: {e}"
                )
        return loaded
        
