from abc import ABCMeta, abstractmethod
from typing import Any
import astropy.units as u
import numpy as np
import tables
from astropy.table import Table
from astropy.time import Time
from scipy.interpolate import interp1d
from ..core import Component, traits
from ..io.hdf5dataformat import (
    DL0_TEL_POINTING_GROUP,
    DL1_FLATFIELD_IMAGE_GROUP,
    DL1_FLATFIELD_PEAK_TIME_GROUP,
    DL1_SKY_PEDESTAL_IMAGE_GROUP,
)
# Public names of this module; this list controls what
# ``from <module> import *`` exports.
__all__ = [
    "MonitoringInterpolator",
    "LinearInterpolator",
    "PointingInterpolator",
    "ChunkInterpolator",
    "StatisticsInterpolator",
    "PedestalImageInterpolator",
    "FlatfieldImageInterpolator",
    "FlatfieldPeakTimeInterpolator",
]
[docs]
class MonitoringInterpolator(Component, metaclass=ABCMeta):
    """
    MonitoringInterpolator parent class.

    The helper methods of this class read the following class attributes,
    which concrete subclasses have to define:

    - ``required_columns``: set of column names that must be in a table,
    - ``expected_units``: mapping of column name to the expected unit
      (or ``None`` if the column must not carry a unit),
    - ``telescope_data_group``: hdf5 group under which the per-telescope
      tables ``tel_XXX`` are looked up.

    Parameters
    ----------
    h5file : None | tables.File
        An open hdf5 file with read access.

    Raises
    ------
    TypeError
        If ``h5file`` is neither ``None`` nor a ``tables.File``.
    """

    def __init__(self, h5file: None | tables.File = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if h5file is not None and not isinstance(h5file, tables.File):
            raise TypeError("h5file must be a tables.File")

        self.h5file = h5file

    @abstractmethod
    def __call__(self, tel_id: int, time: Time):
        """
        Interpolates monitoring data for a given timestamp

        Parameters
        ----------
        tel_id : int
            Telescope id.
        time : astropy.time.Time
            Time for which to interpolate the monitoring data.
        """

    @abstractmethod
    def add_table(self, tel_id: int, input_table: Table) -> None:
        """
        Add a table to this interpolator.

        This method reads input tables and creates instances of the needed interpolators.
        The first index of _interpolators needs to be tel_id, the second needs to be
        the name of the parameter that is to be interpolated.

        Parameters
        ----------
        tel_id : int
            Telescope id.
        input_table : astropy.table.Table
            Table of monitoring values, expected columns are the
            time column(s) and other columns for the data
            that is to be interpolated.
        """

    def _check_tables(self, input_table: Table) -> None:
        """
        Validate the columns and units of an input table.

        Parameters
        ----------
        input_table : astropy.table.Table
            Table to validate against ``required_columns`` and
            ``expected_units``.

        Raises
        ------
        ValueError
            If a required column is missing or if the unit of a column is
            not compatible with the expected unit.
        """
        missing = self.required_columns - set(input_table.colnames)
        if len(missing) > 0:
            raise ValueError(f"Table is missing required column(s): {missing}")

        for col, expected in self.expected_units.items():
            unit = input_table[col].unit

            # Handle the "no unit expected" case first: ``None`` has no
            # ``is_equivalent`` method, so it must never reach the check below.
            if expected is None:
                if unit is not None:
                    raise ValueError(f"{col} must have units compatible with 'None'")
                continue

            if unit is None or not expected.is_equivalent(unit):
                raise ValueError(
                    f"{col} must have units compatible with '{expected.name}'"
                )

    def _read_parameter_table(self, tel_id: int) -> None:
        """
        Read the table of telescope ``tel_id`` from ``h5file`` and add it.

        The table is read from ``{telescope_data_group}/tel_{tel_id:03d}``
        and passed on to ``add_table``.

        Parameters
        ----------
        tel_id : int
            Telescope id.
        """
        # prevent circular import between io and monitoring
        from ..io import read_table

        input_table = read_table(
            self.h5file,
            f"{self.telescope_data_group}/tel_{tel_id:03d}",
        )
        self.add_table(tel_id, input_table)
[docs]
class LinearInterpolator(MonitoringInterpolator):
    """
    LinearInterpolator parent class.

    Translates the ``bounds_error`` and ``extrapolate`` traits into the
    keyword arguments stored in ``interp_options`` and keeps the
    per-telescope interpolators in ``_interpolators``.

    Parameters
    ----------
    h5file : None | tables.File
        An open hdf5 file with read access.
    """

    bounds_error = traits.Bool(
        default_value=True,
        help="If true, raises an exception when trying to extrapolate out of the given table",
    ).tag(config=True)

    extrapolate = traits.Bool(
        help=(
            "If bounds_error is False, this flag will specify whether values outside "
            "the available values are filled with nan (False) or extrapolated (True)."
        ),
        default_value=False,
    ).tag(config=True)

    def __init__(self, h5file: None | tables.File = None, **kwargs: Any) -> None:
        super().__init__(h5file, **kwargs)

        # maps tel_id -> interpolators registered for that telescope
        self._interpolators = {}

        self.interp_options: dict[str, Any] = dict(assume_sorted=True, copy=False)
        if self.bounds_error:
            self.interp_options["bounds_error"] = True
        elif self.extrapolate:
            self.interp_options["bounds_error"] = False
            self.interp_options["fill_value"] = "extrapolate"
        else:
            self.interp_options["bounds_error"] = False
            self.interp_options["fill_value"] = np.nan

    def _check_interpolators(self, tel_id: int) -> None:
        """
        Make sure interpolators for ``tel_id`` are available.

        If none were added yet, the table is read from ``h5file``.

        Parameters
        ----------
        tel_id : int
            Telescope id.

        Raises
        ------
        KeyError
            If no interpolators exist for ``tel_id`` and no ``h5file``
            was given to read the table from.
        """
        if tel_id in self._interpolators:
            return

        if self.h5file is None:
            raise KeyError(f"No table available for tel_id {tel_id}")

        self._read_parameter_table(tel_id)
[docs]
class PointingInterpolator(LinearInterpolator):
    """
    Interpolator for pointing and pointing correction data.
    """

    telescope_data_group = DL0_TEL_POINTING_GROUP
    required_columns = frozenset(["time", "azimuth", "altitude"])
    expected_units = {"azimuth": u.rad, "altitude": u.rad}

    def __call__(self, tel_id: int, time: Time) -> tuple[u.Quantity, u.Quantity]:
        """
        Interpolate alt/az for given time and tel_id.

        Parameters
        ----------
        tel_id : int
            Telescope id.
        time : astropy.time.Time
            Time for which to interpolate the pointing.

        Returns
        -------
        altitude : astropy.units.Quantity[rad]
            Interpolated altitude angle.
        azimuth : astropy.units.Quantity[rad]
            Interpolated azimuth angle.
        """
        self._check_interpolators(tel_id)

        # the interpolators were built on TAI MJD values, see ``add_table``
        mjd = time.tai.mjd
        interpolators = self._interpolators[tel_id]
        az = u.Quantity(interpolators["az"](mjd), u.rad, copy=False)
        alt = u.Quantity(interpolators["alt"](mjd), u.rad, copy=False)
        return alt, az

    def add_table(self, tel_id: int, input_table: Table) -> None:
        """
        Add a table to this interpolator.

        Parameters
        ----------
        tel_id : int
            Telescope id.
        input_table : astropy.table.Table
            Table of pointing values, expected columns
            are ``time`` as ``Time`` column, ``azimuth`` and ``altitude``
            as quantity columns for pointing and pointing correction data.

        Raises
        ------
        ValueError
            If a required column is missing or has incompatible units.
        TypeError
            If the ``time`` column is not an ``astropy.time.Time``.
        """
        self._check_tables(input_table)

        if not isinstance(input_table["time"], Time):
            raise TypeError("'time' column of pointing table must be astropy.time.Time")

        # sort a copy so the table of the caller is left untouched
        input_table = input_table.copy()
        input_table.sort("time")

        az = input_table["azimuth"].quantity.to_value(u.rad)
        # prepare azimuth for interpolation by "unwrapping": i.e. turning
        # [359, 1] into [359, 361]. This assumes that if we get values like
        # [359, 1] the telescope moved 2 degrees through 0, not 358 degrees
        # the other way around. This should be true for all telescopes given
        # the sampling speed of pointing values and their maximum movement speed.
        # No telescope can turn more than 180° in 2 seconds.
        az = np.unwrap(az)
        alt = input_table["altitude"].quantity.to_value(u.rad)
        mjd = input_table["time"].tai.mjd

        self._interpolators[tel_id] = {
            "az": interp1d(mjd, az, kind="linear", **self.interp_options),
            "alt": interp1d(mjd, alt, kind="linear", **self.interp_options),
        }
 
[docs]
class ChunkInterpolator(MonitoringInterpolator):
    """
    Simple interpolator for overlapping chunks of data.

    Every row of an added table is a chunk that is valid between its
    ``time_start`` and ``time_end``. All entries of ``required_columns``
    other than ``time_start`` and ``time_end`` are treated as data columns.

    Parameters
    ----------
    h5file : None | tables.File
        An open hdf5 file with read access.
    """

    def __init__(self, h5file: None | tables.File = None, **kwargs: Any) -> None:
        # Forward h5file to the parent, which validates and stores it.
        # Without this, ``self.h5file`` would always be ``None`` and tables
        # could never be loaded on demand in ``__call__``.
        super().__init__(h5file, **kwargs)

        # per-telescope chunk boundaries (MJD) and data columns
        self.time_start = {}
        self.time_end = {}
        self.values = {}

        self.columns = list(self.required_columns)  # these will be the data columns
        self.columns.remove("time_start")
        self.columns.remove("time_end")

    def __call__(
        self, tel_id: int, time: Time, timestamp_tolerance: u.Quantity = 0.0 * u.s
    ) -> float | dict[str, float]:
        """
        Interpolate overlapping chunks of data for a given time, tel_id, and column(s).

        Parameters
        ----------
        tel_id : int
            Telescope id.
        time : astropy.time.Time
            Time for which to interpolate the data.
        timestamp_tolerance : astropy.units.Quantity
            Time difference in seconds to consider two timestamps equal. Default is 0s.

        Returns
        -------
        interpolated : float or dict
            Interpolated data for the specified column(s). If there is only
            a single data column, its value is returned directly, otherwise
            a dict mapping column name to value.

        Raises
        ------
        KeyError
            If no table was added for ``tel_id`` and no ``h5file`` was
            given to read it from.
        """
        if tel_id not in self.values:
            if self.h5file is None:
                raise KeyError(f"No table available for tel_id {tel_id}")
            self._read_parameter_table(tel_id)

        result = {
            column: self._interpolate_chunk(tel_id, column, time, timestamp_tolerance)
            for column in self.columns
        }

        if len(self.columns) == 1:
            return result[self.columns[0]]
        return result

    def add_table(self, tel_id: int, input_table: Table) -> None:
        """
        Add a table to this interpolator for specific columns.

        Parameters
        ----------
        tel_id : int
            Telescope id.
        input_table : astropy.table.Table
            Table of values to be interpolated, expected columns
            are ``time_start`` as ``validity start Time`` column,
            ``time_end`` as ``validity end Time`` and the specified columns
            for the data of the chunks.

        Raises
        ------
        ValueError
            If a required column is missing or has incompatible units.
        """
        self._check_tables(input_table)

        # sort a copy so the table of the caller is left untouched
        input_table = input_table.copy()
        input_table.sort("time_start")

        self.time_start[tel_id] = input_table["time_start"].to_value("mjd")
        self.time_end[tel_id] = input_table["time_end"].to_value("mjd")
        self.values[tel_id] = {column: input_table[column] for column in self.columns}

    def _interpolate_chunk(
        self, tel_id, column, time: Time, timestamp_tolerance: u.Quantity = 0.0 * u.s
    ) -> float | list[float]:
        """
        Interpolates overlapping chunks of data preferring earlier chunks if valid

        Parameters
        ----------
        tel_id : int
            tel_id for which data is to be interpolated
        column : str
            Name of the data column to interpolate.
        time : astropy.time.Time
            Time for which to interpolate the data.
        timestamp_tolerance : astropy.units.Quantity
            Time difference in seconds to consider two timestamps equal. Default is 0s.

        Returns
        -------
        interpolated
            The value for ``time`` if a single time was given, otherwise a
            list with one value per requested time. Entries not covered by
            any chunk are NaN.
        """
        time_start = self.time_start[tel_id]
        time_end = self.time_end[tel_id]
        values = self.values[tel_id][column]
        mjd_times = np.atleast_1d(time.to_value("mjd"))

        # Convert timestamp tolerance to MJD days
        tolerance_mjd = timestamp_tolerance.to_value("day")

        # Find the index of the closest preceding start time
        preceding_indices = np.searchsorted(time_start, mjd_times, side="right") - 1

        interpolated_values = []
        for mjd, preceding_index in zip(mjd_times, preceding_indices):
            # Default value is NaN or array of NaNs
            value = (
                np.nan if np.isscalar(values[0]) else np.full_like(values[0], np.nan)
            )

            # Check if the requested time is before the first chunk
            if preceding_index < 0:
                # If the time is before the first chunk and not within tolerance, return NaN
                if (time_start[0] - tolerance_mjd) > mjd:
                    interpolated_values.append(value)
                    continue
                # Use the first chunk since it's within tolerance
                preceding_index = 0

            # Check if the time is within the valid range of the chunk
            if (
                (time_start[preceding_index] - tolerance_mjd)
                <= mjd
                <= (time_end[preceding_index] + tolerance_mjd)
            ):
                value = values[preceding_index]
                # If no NaN values, we can append immediately and continue
                if np.all(~np.isnan(value)):
                    interpolated_values.append(value)
                    continue

            # Fill NaN values from earlier overlapping chunks
            for i in range(preceding_index - 1, -1, -1):
                if (
                    (time_start[i] - tolerance_mjd)
                    <= mjd
                    <= (time_end[i] + tolerance_mjd)
                ):
                    # Only fill NaN values
                    value = np.where(np.isnan(value), values[i], value)
                    # If no NaN values left, we can stop
                    if np.all(~np.isnan(value)):
                        break

            interpolated_values.append(value)

        # If only a single time was provided, return the value directly
        if len(interpolated_values) == 1:
            interpolated_values = interpolated_values[0]

        return interpolated_values
[docs]
class StatisticsInterpolator(ChunkInterpolator):
    """
    Chunk interpolator for tables of statistics.

    Requires the columns ``time_start``, ``time_end``, ``mean``,
    ``median`` and ``std``; no unit is expected for the three
    statistics columns.
    """

    required_columns = frozenset(("time_start", "time_end", "mean", "median", "std"))
    expected_units = dict.fromkeys(("mean", "median", "std"))
[docs]
class PedestalImageInterpolator(StatisticsInterpolator):
    """
    Statistics interpolator for pedestal image tables.

    Tables are looked up under ``DL1_SKY_PEDESTAL_IMAGE_GROUP``.
    """

    telescope_data_group = DL1_SKY_PEDESTAL_IMAGE_GROUP
[docs]
class FlatfieldImageInterpolator(StatisticsInterpolator):
    """
    Statistics interpolator for flatfield image tables.

    Tables are looked up under ``DL1_FLATFIELD_IMAGE_GROUP``.
    """

    telescope_data_group = DL1_FLATFIELD_IMAGE_GROUP
[docs]
class FlatfieldPeakTimeInterpolator(StatisticsInterpolator):
    """
    Statistics interpolator for flatfield peak time tables.

    Tables are looked up under ``DL1_FLATFIELD_PEAK_TIME_GROUP``.
    """

    telescope_data_group = DL1_FLATFIELD_PEAK_TIME_GROUP