import boto3
import re
import logging
from typing import Optional, Any

from .distance_ocr import DistanceOCR

# Configure root logging as a side effect of importing this module: INFO level,
# records formatted as "timestamp - logger name - level - message".
# logging.basicConfig does nothing when the root logger already has handlers.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _detect_text_in_image(client, image_bytes: bytes) -> list[str]:
    """
    Run Rekognition text detection on an image and collect its line-level text.

    :param client: object exposing a Rekognition-style ``detect_text`` method
    :param image_bytes: raw image content, sent inline as ``Image["Bytes"]``
    :return: ``DetectedText`` of every detection whose ``Type`` is ``"LINE"``,
        in the order the response lists them (empty list when the response has
        no ``TextDetections`` key)
    """
    result = client.detect_text(Image={"Bytes": image_bytes})

    lines: list[str] = []
    for detection in result.get("TextDetections", []):
        # Keep whole lines only; every other detection type is skipped.
        if detection["Type"] != "LINE":
            continue
        lines.append(detection["DetectedText"])
    return lines


_SPLIT_DECIMAL_PATTERN = re.compile(r"(\d+)[.,]\s*$")  # e.g. "12." or "12," at end of line

# Normalize OCR spacing: "26 .01" or "26 , 01" -> "26.01" (avoids matching only ".01")
_NORMALIZE_SPACED_DECIMAL = re.compile(r"(\d+)\s+([.,])\s*(\d+)")


def _extract_distance(
    texts: list[str],
    units: Optional[list[str]] = None,
    verbose: bool = False,
) -> float | None:
    """
    Extracts a distance value from a list of OCR text lines.

    Handles OCR text that may split decimal numbers across lines (e.g., "12." on one
    line and "5 m" on the next).

    :param texts: List of OCR text lines to search
    :param units: List of valid distance units (default ['m', 'ft'])
    :param verbose: if True, log matched lines used for distance
    :return: Extracted distance as float, or None if not found
    """
    if units is None:
        units = ["m", "ft"]

    units_pattern = "|".join(map(re.escape, units))
    distance_pattern = re.compile(
        rf"(\d+(?:[.,]\d+)?)\s*(?:{units_pattern})\b", re.IGNORECASE
    )

    for i, text in enumerate(texts):
        text = text.replace("\xa0", " ").strip()
        text = _NORMALIZE_SPACED_DECIMAL.sub(r"\1\2\3", text)  # "26 .01" -> "26.01"
        match = distance_pattern.search(text)
        if not match:
            continue
        raw_value = match.group(1)
        if verbose:
            logger.info(
                f"Found match: {raw_value}"
            )
        normalized_value = float(raw_value.replace(",", "."))

        # Handle OCR split decimals: prev line "12." + current "5 m" -> 12.5
        is_integer_only = "." not in raw_value and "," not in raw_value
        if is_integer_only and i > 0:
            prev_text = texts[i - 1].replace("\xa0", " ").strip()
            prev_match = _SPLIT_DECIMAL_PATTERN.search(prev_text)
            if prev_match:
                whole_part = prev_match.group(1)
                normalized_value = float(f"{whole_part}.{raw_value}")
                if verbose:
                    logger.info(
                        "Matched (split decimal) lines [%d, %d]: %r + %r -> %s",
                        i - 1,
                        i,
                        texts[i - 1],
                        texts[i],
                        normalized_value,
                    )
                return normalized_value

        if verbose:
            logger.info(
                "Matched line [%d]: %r -> %s",
                i,
                texts[i],
                normalized_value,
            )
        return normalized_value

    return None


class AWSDistanceExtractor(DistanceOCR):
    """
    Distance OCR backend that reads frame text with AWS Rekognition.

    Both public entry points run the same pipeline: detect the text lines in the
    image, optionally log them, and parse the first distance value out of them.
    """

    def __init__(
        self, rekognition_client: Optional[Any] = None, region_name: str = "ap-south-1"
    ):
        """
        :param rekognition_client: Client object exposing ``detect_text``; when
            None, a boto3 Rekognition client is created
        :param region_name: AWS region for the client created here (unused when
            a client is passed in)
        """
        # Explicit None check so a caller-supplied client is never replaced just
        # because it happens to evaluate as falsy.
        if rekognition_client is None:
            rekognition_client = boto3.client("rekognition", region_name=region_name)
        self.client = rekognition_client

    def _distance_from_bytes(
        self,
        image_bytes: bytes,
        units: Optional[list[str]],
        verbose: bool,
    ) -> float | None:
        """Shared pipeline: OCR the image bytes, then parse a distance from the lines."""
        text_lines = _detect_text_in_image(self.client, image_bytes)

        if verbose:
            for i, text in enumerate(text_lines):
                logger.info("Checking string [%d]: %r", i, text)

        return _extract_distance(text_lines, units=units, verbose=verbose)

    def get_image_distance(
        self,
        image_path: str,
        *,
        units: Optional[list[str]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> float | None:
        """
        Gets live OCR distance for a specific frame.

        :param image_path: Path to the frame image file
        :param units: list of possible units
        :param verbose: if True, log all OCR strings and which lines matched (default False)
        :param kwargs: accepted and ignored
        :return: Distance or None
        """
        with open(image_path, "rb") as f:
            frame_data = f.read()
        return self._distance_from_bytes(frame_data, units, verbose)

    def get_distance_from_image_data(
        self,
        image_bytes: bytes,
        *,
        units: Optional[list[str]] = None,
        verbose: bool = False,
        **kwargs,
    ) -> float | None:
        """
        Gets live OCR distance from image bytes (no file path).

        :param image_bytes: Image data as bytes
        :param units: list of possible units
        :param verbose: if True, log all OCR strings and which lines matched (default False)
        :param kwargs: accepted and ignored
        :return: Distance or None
        """
        return self._distance_from_bytes(image_bytes, units, verbose)

def main() -> None:
    """
    Command-line entry point: print the distance read from one image file.

    Takes a single positional argument (the image path), runs it through
    AWSDistanceExtractor with default settings, and prints the distance, or
    "NONE" when no distance was found.
    """
    import argparse
    from dotenv import load_dotenv

    # Load variables from a .env file before the Rekognition client is created
    # (presumably so AWS credentials can be supplied that way — confirm).
    load_dotenv()

    parser = argparse.ArgumentParser(
        description="Extract distance from pipe inspection image using AWS Rekognition."
    )
    parser.add_argument("image", help="Path to the image file")
    args = parser.parse_args()

    extractor = AWSDistanceExtractor()
    result = extractor.get_image_distance(args.image)
    print(result if result is not None else "NONE")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()