"""
Traitlet implementations for ctapipe
"""
import os
import pathlib
import warnings
from collections.abc import Mapping
from urllib.parse import urlparse
import astropy.units as u
import traitlets
import traitlets.config
from astropy.time import Time
from traitlets import Undefined
from ctapipe.core.plugins import detect_and_import_plugins
from .component import Component, non_abstract_children
from .telescope_component import TelescopeParameter
# Public API of this module; ``from ctapipe.core.traits import *`` exports
# exactly these names.
__all__ = [
    # Implemented here
    "AstroQuantity",
    "AstroTime",
    "BoolTelescopeParameter",
    "ComponentName",
    "ComponentNameList",
    "IntTelescopeParameter",
    "FloatTelescopeParameter",
    "Path",
    "classes_with_traits",
    "create_class_enum_trait",
    "has_traits",
    # imported from traitlets
    "Bool",
    "CRegExp",
    "CaselessStrEnum",
    "CInt",
    "Dict",
    "Enum",
    "Float",
    "Int",
    "Integer",
    "List",
    "Long",
    "Set",
    "TraitError",
    "Tuple",
    "Unicode",
    "UseEnum",
    "flag",
    "observe",
]
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class PathError:
    """Describe why a configured path was rejected.

    Instances are used as the message argument of the ``TraitError``
    raised by ``Path`` when a path fails its checks.

    Parameters
    ----------
    path : object
        The offending path.
    reason : str
        Human-readable explanation of the problem.
    """

    def __init__(self, path, reason):
        self.path, self.reason = path, reason

    def __repr__(self):
        # rendered as: '<path>': <reason>
        return "'{}': {}".format(self.path, self.reason)
class NoneDefaultNotAllowedWarning(UserWarning):
    """Warning for a trait declared with default_value=None but allow_none=False."""

    pass
# Aliases: re-export selected traitlets names so they can be imported
# directly from this module.

# numeric and boolean traits
Bool = traitlets.Bool
Int = traitlets.Int
Integer = traitlets.Integer
CInt = traitlets.CInt
Long = traitlets.Long
Float = traitlets.Float

# string and choice traits
Unicode = traitlets.Unicode
CRegExp = traitlets.CRegExp
Enum = traitlets.Enum
CaselessStrEnum = traitlets.CaselessStrEnum
UseEnum = traitlets.UseEnum

# container traits
Dict = traitlets.Dict
List = traitlets.List
Set = traitlets.Set
Tuple = traitlets.Tuple

# base class, error type and helpers
TraitType = traitlets.TraitType
TraitError = traitlets.TraitError
observe = traitlets.observe
flag = traitlets.config.boolean_flag
class AstroQuantity(TraitType):
    """A trait containing an ``astropy.units`` quantity.

    Parameters
    ----------
    physical_type : astropy.units.PhysicalType or astropy.units.UnitBase, optional
        If given, every value (and the default value, if one is set) must be
        of this physical type. A unit is converted to its physical type with
        ``astropy.units.get_physical_type``.
    **kwargs
        Passed on to ``traitlets.TraitType`` (e.g. ``default_value``,
        ``allow_none``, ``help``).

    Raises
    ------
    TraitError
        If ``physical_type`` has an unsupported type, or if the physical
        type of the default value does not match it.
    """

    def __init__(self, physical_type=None, **kwargs):
        super().__init__(**kwargs)

        if physical_type is None or isinstance(physical_type, u.PhysicalType):
            self.physical_type = physical_type
        elif isinstance(physical_type, u.UnitBase):
            self.physical_type = u.get_physical_type(physical_type)
        else:
            raise TraitError(
                "Given physical type must be either of type"
                " astropy.units.PhysicalType or a subclass of"
                f" astropy.units.UnitBase, was {type(physical_type)}."
            )

        # A default of None (used together with allow_none=True) is not a
        # quantity and has no physical type, so there is nothing to check.
        has_default = (
            self.default_value is not Undefined and self.default_value is not None
        )
        if has_default and self.physical_type is not None:
            default_type = u.get_physical_type(self.default_value)
            if default_type != self.physical_type:
                raise TraitError(
                    f"Given physical type {self.physical_type} does not match"
                    f" physical type of the default value, {default_type}."
                )

    @property
    def info_text(self):
        info = "An ``astropy.units.Quantity`` instance"
        if self.allow_none:
            info += " or None"
        return info

    def validate(self, obj, value):
        """Convert ``value`` to a Quantity and check its physical type.

        ``value`` may be anything ``astropy.units.Quantity`` accepts, or a
        mapping of keyword arguments for it (e.g. ``{"value": 1, "unit": "m"}``).

        Raises
        ------
        TraitError
            If ``value`` cannot be converted, or if its physical type does
            not match ``physical_type``.
        """
        try:
            if isinstance(value, Mapping):
                quantity = u.Quantity(**value)
            else:
                quantity = u.Quantity(value)
        except (TypeError, ValueError):
            # traitlets' error() raises a TraitError describing the bad value
            self.error(obj, value)

        if self.physical_type is not None:
            given_type = u.get_physical_type(quantity)
            if given_type != self.physical_type:
                raise TraitError(
                    f"Given quantity is of physical type {given_type}."
                    f" Expected {self.physical_type}."
                )
        return quantity
 
class AstroTime(TraitType):
    """A trait representing a point in Time, as understood by ``astropy.time``."""

    def validate(self, obj, value):
        """Parse ``value`` into an ``astropy.time.Time`` using the ``iso`` format.

        Anything accepted by ``Time(value)`` works, e.g. an ISO 8601 string
        or an existing ``Time`` instance. The returned ``Time`` object has its
        ``format`` set to ``"iso"``.

        Raises
        ------
        TraitError
            If ``value`` cannot be interpreted as a time.
        """
        try:
            the_time = Time(value)
            the_time.format = "iso"
            return the_time
        except (TypeError, ValueError):
            # traitlets' error() raises a TraitError describing the bad value
            self.error(obj, value)

    @property
    def info_text(self):
        info = "an ISO8601 datestring or Time instance"
        if self.allow_none:
            info += " or None"
        return info
class Path(TraitType):
    """
    A path Trait for input/output files.

    Accepted values are a ``pathlib.Path``, a non-empty ``str`` or ``bytes``
    (and None if ``allow_none``). Environment variables are expanded,
    ``http``/``https`` URLs are downloaded via
    ``ctapipe.utils.download.download_cached`` and ``dataset://`` URLs are
    resolved via ``ctapipe.utils.get_dataset_path``. The validated value is
    made absolute and then checked against the options below.

    Attributes
    ----------
    exists: boolean or None
        If True, path must exist, if False path must not exist
    directory_ok: boolean
        If False, path must not be a directory
    file_ok: boolean
        If False, path must not be a file
    """

    def __init__(
        self,
        default_value=None,
        allow_none=True,
        exists=None,
        directory_ok=True,
        file_ok=True,
        **kwargs,
    ):
        super().__init__(default_value=default_value, allow_none=allow_none, **kwargs)

        if self.default_value is None and not self.allow_none:
            msg = (
                "Setting default_value=None and allow_none=False will result in misleading error messages."
                " To require a value, set allow_none=True, default_value=None and check later that"
                " the value is not None"
            )
            warnings.warn(msg, NoneDefaultNotAllowedWarning)

        self.exists = exists
        self.directory_ok = directory_ok
        self.file_ok = file_ok

    @property
    def info_text(self):
        info = "a pathlib.Path or non-empty str for "
        if self.exists is True:
            info += "an existing"
        elif self.exists is False:
            info += "a not existing"
        else:
            info += "a"

        if self.directory_ok and self.file_ok:
            info += " directory or file"
        elif self.file_ok:
            info += " file"
        elif self.directory_ok:
            info += " directory"

        if self.allow_none:
            info += " or None"
        return info

    def validate(self, obj, value):
        """Validate trait value and convert it to an absolute path."""
        if value is None or value is Undefined:
            if self.allow_none:
                return value
            self.error(obj, value)
        elif isinstance(value, bytes):
            value = os.fsdecode(value)
        elif not isinstance(value, str | pathlib.Path):
            self.error(obj, value)

        # expand any environment variables in the path (returns a str):
        value = os.path.expandvars(value)
        value = self._validate_str(obj, value)
        return self._check_path(value)

    def _validate_str(self, obj, value):
        """Turn a non-empty path or URL string into a path.

        ``http``/``https`` URLs are downloaded, ``dataset://`` URLs are
        resolved, ``file://`` URLs and plain strings become ``pathlib.Path``.
        Any other URL scheme is an error.
        """
        if value == "":
            self.error(obj, value)

        try:
            url = urlparse(value)
        except ValueError:
            self.error(obj, value)

        if url.scheme in ("http", "https"):
            # here to avoid circular import, since every module imports
            # from ctapipe.core
            from ctapipe.utils.download import download_cached

            value = download_cached(value, progress=True)
        elif url.scheme == "dataset":
            # here to avoid circular import, since every module imports
            # from ctapipe.core
            from ctapipe.utils import get_dataset_path

            value = get_dataset_path(value.partition("dataset://")[2])
        elif url.scheme == "file":
            value = pathlib.Path(url.netloc, url.path)
        elif url.scheme == "":
            # Plain local path: use the string itself. urlparse splits off
            # ";params", "?query" and "#fragment", so rebuilding the path from
            # url.path would truncate valid file names such as "run#1.h5".
            value = pathlib.Path(value)
        else:
            self.error(obj, value)
        return value

    def _check_path(self, value):
        """Make ``value`` absolute and enforce ``exists``/``directory_ok``/``file_ok``."""
        value = value.absolute()
        exists = value.exists()
        if self.exists is not None and exists != self.exists:
            msg = "does not exist" if self.exists else "exists but must not"
            raise TraitError(PathError(value, msg), self.info(), self)

        if exists:
            if not self.directory_ok and value.is_dir():
                raise TraitError(PathError(value, "is a directory"), self.info(), self)
            if not self.file_ok and value.is_file():
                raise TraitError(PathError(value, "is a file"), self.info(), self)

        return value
def create_class_enum_trait(base_class, default_value, help=None, allow_none=False):
    """Create a configurable CaselessStrEnum traitlet from a base class.

    The allowed values are the class names returned by
    ``non_abstract_children(base_class)``.

    Parameters
    ----------
    base_class : type
        Class whose non-abstract children define the choices.
    default_value : str
        Default choice; must be the name of one of those children.
    help : str or None
        Help text; defaults to ``"<base_class name> to use."``.
    allow_none : bool
        Whether None is an allowed value.

    Returns
    -------
    CaselessStrEnum
        The trait, tagged with ``config=True``.

    Raises
    ------
    ValueError
        If ``default_value`` is not the name of a non-abstract child class.
    """
    if help is None:
        help = f"{base_class.__name__} to use."

    choices = [cls.__name__ for cls in non_abstract_children(base_class)]
    if default_value not in choices:
        raise ValueError(f"{default_value} is not in choices: {choices}")

    return CaselessStrEnum(
        choices,
        default_value=default_value,
        help=help,
        allow_none=allow_none,
    ).tag(config=True)
class ComponentName(Unicode):
    """A trait whose value is the name of a subclass of a given Component.

    Parameters
    ----------
    cls : type
        A ``Component`` subclass; valid values are the entries of
        ``cls.non_abstract_subclasses()``.
    **kwargs
        Passed on to ``traitlets.Unicode``.

    Raises
    ------
    TypeError
        If ``cls`` is not a ``Component`` subclass.
    """

    def __init__(self, cls, **kwargs):
        # Until this flag is True, ``help``/``info_text`` do not call
        # ``cls.non_abstract_subclasses()``: doing so while the trait is
        # being defined (import time) could trigger plugin imports and
        # therefore circular imports.
        self._init_done = False
        if not issubclass(cls, Component):
            raise TypeError(f"cls must be a Component, got {cls}")
        self.cls = cls

        help_given = "help" in kwargs
        super().__init__(**kwargs)
        if not help_given:
            self.help = f"The name of a {cls.__name__} subclass"
        self._init_done = True

    @property
    def help(self):
        """Help text, followed by the possible values once initialized."""
        possible = list(self.cls.non_abstract_subclasses()) if self._init_done else []
        return f"{self._help}. Possible values: {possible}"

    @help.setter
    def help(self, value):
        self._help = value

    @property
    def info_text(self):
        if not self._init_done:
            return f"Any subclass of {self.cls}"
        return f"Any of {list(self.cls.non_abstract_subclasses())}"

    def validate(self, obj, value):
        """Accept None (if allowed) or a name of a non-abstract subclass."""
        if value is None and self.allow_none:
            return None
        known = self.cls.non_abstract_subclasses()
        if value in known:
            return value
        self.error(obj, value)
class ComponentNameList(List):
    """A list trait whose items are names of subclasses of a given Component.

    Each item is validated by a ``ComponentName`` trait for ``cls``.

    Parameters
    ----------
    cls : type
        A ``Component`` subclass.
    **kwargs
        Passed on to ``traitlets.List``.

    Raises
    ------
    TypeError
        If ``cls`` is not a ``Component`` subclass.
    """

    def __init__(self, cls, **kwargs):
        # Until this flag is True, ``help``/``info_text`` do not call
        # ``cls.non_abstract_subclasses()``: doing so while the trait is
        # being defined (import time) could trigger plugin imports and
        # therefore circular imports.
        self._init_done = False
        if not issubclass(cls, Component):
            raise TypeError(f"cls must be a Component, got {cls}")
        self.cls = cls

        help_given = "help" in kwargs
        super().__init__(trait=ComponentName(cls), **kwargs)
        if not help_given:
            self.help = f"A list of {cls.__name__} subclass names"
        self._init_done = True

    @property
    def help(self):
        """Help text, followed by the possible values once initialized."""
        possible = list(self.cls.non_abstract_subclasses()) if self._init_done else []
        return f"{self._help}. Possible values: {possible}"

    @help.setter
    def help(self, value):
        self._help = value

    @property
    def info_text(self):
        if not self._init_done:
            return f"A list of {self.cls} subclasses"
        return f"A list of {list(self.cls.non_abstract_subclasses())}"
def classes_with_traits(base_class):
    """Return ``base_class`` and its non-abstract children that have traits.

    If ``base_class`` has a ``plugin_entry_point`` attribute, plugins for that
    entry point are imported first. Classes listed in a class's ``classes``
    attribute (a ``List`` trait or a plain iterable) are searched recursively.
    Sub-components are collected on a best-effort basis: one that fails is
    logged at debug level and skipped, without affecting the others.

    Parameters
    ----------
    base_class : type
        Class to start from.

    Returns
    -------
    list of type
        Classes for which ``has_traits`` is true (no de-duplication is done).
    """
    if hasattr(base_class, "plugin_entry_point"):
        detect_and_import_plugins(base_class.plugin_entry_point)

    all_classes = [base_class] + non_abstract_children(base_class)

    with_traits = []
    for cls in all_classes:
        if has_traits(cls):
            with_traits.append(cls)

        # add subcomponents
        if not hasattr(cls, "classes"):
            continue

        if isinstance(cls.classes, List):
            classes = cls.classes.default()
        else:
            classes = cls.classes

        # we ignore failing classes to not break anyone, but each one is
        # skipped individually so it does not hide the remaining ones
        try:
            components = list(classes)
        except Exception:
            logger.debug("Could not iterate sub-components of %s", cls, exc_info=True)
            continue

        for component in components:
            try:
                with_traits.extend(classes_with_traits(component))
            except Exception:
                logger.debug(
                    "Ignoring sub-component %s of %s", component, cls, exc_info=True
                )

    return with_traits
def has_traits(cls, ignore=("config", "parent")):
    """Return True if ``cls`` has any traits apart from the ignored ones.

    All our components have at least 'config' and 'parent' as traitlets,
    inherited from `traitlets.config.Configurable`, so these are ignored
    by default.

    Parameters
    ----------
    cls : type
        Class providing ``class_trait_names()``.
    ignore : iterable of str
        Trait names that do not count.

    Returns
    -------
    bool
    """
    return bool(set(cls.class_trait_names()) - set(ignore))
class FloatTelescopeParameter(TelescopeParameter):
    """a `~ctapipe.core.telescope_component.TelescopeParameter` with Float trait type"""

    def __init__(self, **kwargs):
        """Create a new FloatTelescopeParameter; ``kwargs`` go to ``TelescopeParameter``."""
        super().__init__(trait=Float(), **kwargs)
class IntTelescopeParameter(TelescopeParameter):
    """a `~ctapipe.core.telescope_component.TelescopeParameter` with Int trait type"""

    def __init__(self, **kwargs):
        """Create a new IntTelescopeParameter; ``kwargs`` go to ``TelescopeParameter``."""
        super().__init__(trait=Integer(), **kwargs)
class BoolTelescopeParameter(TelescopeParameter):
    """a `~ctapipe.core.telescope_component.TelescopeParameter` with Bool trait type"""

    def __init__(self, **kwargs):
        """Create a new BoolTelescopeParameter; ``kwargs`` go to ``TelescopeParameter``."""
        super().__init__(trait=Bool(), **kwargs)