import xml.etree.ElementTree as ET
from datetime import date, time
from typing import Any, Dict, Type, TypeVar
from xml.dom import minidom

from pydantic import BaseModel

from .abbreviated_enum import AbbreviatedEnum

# Type variable bound to XMLModel; used to annotate classmethods (from_xml, read_xml_file)
# so they return an instance of the class they are invoked on.
T = TypeVar("T", bound="XMLModel")


class XMLModel(BaseModel):
    """
    Base class for all XML models. Provides hooks and helpers for XML serialization and deserialization.

    Serialization is configured through an optional nested ``Config`` class on the subclass.
    Each of these attributes is optional and read defensively:

    - ``element_name_map``: dict mapping a field name to the XML tag/attribute name to emit.
    - ``attribute_fields``: field names written as XML attributes instead of child elements.
    - ``ignore_fields``: field names that are not serialized at all.
    - ``flatten_fields``: list-valued field names whose items are emitted directly as repeated
      children of the parent instead of inside a wrapper element of the same name.
    """

    @classmethod
    def _xml_config(cls, name: str, default: Any) -> Any:
        """Return attribute ``name`` of the nested ``Config`` class, or ``default`` if either is absent."""
        return getattr(getattr(cls, "Config", None), name, default)

    @property
    def element_name_map(self) -> Dict[str, str]:
        """Field-name -> XML-name overrides from ``Config.element_name_map`` (empty if unset)."""
        return self._xml_config("element_name_map", {})

    @property
    def attribute_fields(self) -> list[str]:
        """Fields serialized as XML attributes, from ``Config.attribute_fields`` (empty if unset)."""
        return self._xml_config("attribute_fields", [])

    @property
    def ignore_fields(self) -> list[str]:
        """Fields skipped during serialization, from ``Config.ignore_fields`` (empty if unset)."""
        return self._xml_config("ignore_fields", [])

    @property
    def flatten_fields(self) -> list[str]:
        """List fields emitted without a wrapper element, from ``Config.flatten_fields`` (empty if unset)."""
        return self._xml_config("flatten_fields", [])

    def prepare_for_xml(self) -> Any:
        """
        Hook for customizing the data before serialization.

        Override in subclasses to customize export shape (e.g., wrap items, rename, etc.).
        Return either an XMLModel (the default returns ``self``), whose fields are then written
        according to that model's ``Config``, or a plain structure such as a dict or list,
        which is serialized generically. The hook is invoked exactly once per model.
        """
        return self

    def get_extra_fields(self) -> Dict[str, Any]:
        """
        Hook for injecting extra fields dynamically during XML serialization.

        The returned pairs are merged over the model's own fields (overriding same-named ones)
        and are subject to the same ignore/attribute/flatten/name-map rules.
        """
        return {}

    @staticmethod
    def _xml_to_dict(xml_string: str, encoding: str = "utf-8") -> Dict[str, Any]:
        """
        Convert an XML string to a dictionary.
        Not yet implemented — customize as needed.

        Raises:
            NotImplementedError: always, until a parser is provided.
        """
        raise NotImplementedError("Deserialization from XML not implemented yet.")

    @staticmethod
    def _format_scalar(value: Any) -> str:
        """Render a leaf value as XML text; shared by element text and attribute values."""
        # NOTE(review): datetime is a subclass of date, so datetime values hit this branch and
        # lose their time-of-day component — confirm that is intended.
        if isinstance(value, date):
            return value.strftime("%Y/%m/%d")
        if isinstance(value, time):
            return value.strftime("%I:%M:%S %p")
        if isinstance(value, AbbreviatedEnum):
            return value.label
        return str(value)

    @classmethod
    def _fill_element(cls, element: ET.Element, value: Any) -> None:
        """Populate ``element`` from a non-model value: dict -> children, list -> repeated children, else text."""
        if isinstance(value, dict):
            for key, item in value.items():
                element.append(cls._object_to_element(item, key))
        elif isinstance(value, list):
            # Each list item becomes a child carrying the parent's tag.
            for item in value:
                element.append(cls._object_to_element(item, element.tag))
        elif value is not None:
            element.text = cls._format_scalar(value)

    @classmethod
    def _append_model_fields(cls, element: ET.Element, model: "XMLModel") -> None:
        """Write ``model``'s fields (plus its extra fields) into ``element`` per the model's ``Config``."""
        values = model.__dict__.copy()
        values.update(model.get_extra_fields())

        # Config is static per class, so read it once rather than once per field.
        ignored = model.ignore_fields
        attributes = model.attribute_fields
        flattened = model.flatten_fields
        name_map = model.element_name_map

        for field, value in values.items():
            if field in ignored:
                continue
            field_tag = name_map.get(field, field)

            if field in attributes:
                # A missing attribute is simply omitted; it must not leak into the body
                # as an empty child element.
                if value is not None:
                    element.set(field_tag, cls._format_scalar(value))
            elif field in flattened and isinstance(value, list):
                for item in value:
                    element.append(cls._object_to_element(item, field_tag))
            else:
                element.append(cls._object_to_element(value, field_tag))

    @classmethod
    def _object_to_element(cls, obj: Any, tag: str) -> ET.Element:
        """
        Recursively converts Python object (dict, list, primitive, XMLModel) into XML Element.

        ``tag`` is translated through the ``element_name_map`` of ``cls`` — the class on which
        serialization started, which is propagated through the recursion — so that map is
        applied to every tag in the tree, including ones already renamed by a nested model's map.
        """
        element = ET.Element(cls._xml_config("element_name_map", {}).get(tag, tag))

        if isinstance(obj, XMLModel):
            prepared = obj.prepare_for_xml()
            if isinstance(prepared, XMLModel):
                cls._append_model_fields(element, prepared)
            else:
                # The hook returned a plain structure (e.g. a dict); serialize it generically.
                cls._fill_element(element, prepared)
        else:
            cls._fill_element(element, obj)

        return element

    def to_xml(self, root_name: str = "root", pretty_print: bool = True, encoding: str = "utf-8") -> str:
        """
        Serialize the model to an XML string.

        Args:
            root_name: tag of the root element (subject to this class's ``element_name_map``).
            pretty_print: indent the output with two spaces and omit the XML declaration.
            encoding: codec used for the intermediate byte serialization and decoding.

        Note: without pretty printing, ElementTree only emits an XML declaration for encodings
        other than UTF-8 / US-ASCII.
        """
        # _object_to_element invokes prepare_for_xml itself; calling it here too ran the hook twice.
        root_element = self.__class__._object_to_element(self, root_name)
        xml_bytes = ET.tostring(root_element, encoding=encoding)

        if not pretty_print:
            return xml_bytes.decode(encoding)

        pretty_bytes = minidom.parseString(xml_bytes).toprettyxml(indent="  ", encoding=encoding)
        # toprettyxml always starts with an XML declaration line; drop it.
        return pretty_bytes.decode(encoding).split("\n", 1)[1]

    @classmethod
    def from_xml(cls: Type[T], xml_string: str) -> T:
        """
        Build and validate an instance from an XML string.

        Raises:
            NotImplementedError: while ``_xml_to_dict`` is not implemented.
        """
        return cls.model_validate(cls._xml_to_dict(xml_string))

    @classmethod
    def read_xml_file(cls: Type[T], file_path: str) -> T:
        """Read a UTF-8 XML file and parse it with ``from_xml``."""
        with open(file_path, "r", encoding="utf-8") as f:
            xml_string = f.read()
        return cls.from_xml(xml_string)

    def write_xml_file(self, file_path: str, root_name: str = "root", pretty_print: bool = True) -> None:
        """Serialize with ``to_xml`` (UTF-8) and write the result to ``file_path`` as UTF-8 text."""
        xml_string = self.to_xml(root_name, pretty_print)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(xml_string)
