# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Utilities to generate toymodel (fake) reconstruction inputs for testing
purposes.

Examples
--------
.. code-block:: python

    >>> from ctapipe.instrument import CameraGeometry
    >>> geom = CameraGeometry.make_rectangular(20, 20)
    >>> showermodel = Gaussian(
    ...    x=0.25 * u.m, y=0.0 * u.m,
    ...    length=0.1 * u.m, width=0.02 * u.m,
    ...    psi='40d'
    ... )
    >>> image, signal, noise = showermodel.generate_image(geom, intensity=1000)
    >>> print(image.shape)
    (400,)
"""
from abc import ABCMeta, abstractmethod
import astropy.units as u
import numpy as np
from numpy.random import default_rng
from scipy.integrate import quad
from scipy.ndimage import convolve1d
from scipy.special import ellipe
from scipy.stats import multivariate_normal, norm, skewnorm
from ctapipe.calib.camera.gainselection import GainChannel
from ctapipe.image.hillas import camera_to_shower_coordinates
from ctapipe.utils import linalg
__all__ = [
    "WaveformModel",
    "Gaussian",
    "SkewedGaussian",
    "RingGaussian",
    "ImageModel",
    "obtain_time_image",
]

# Default random generator used by ``ImageModel.generate_image`` when the
# caller does not pass its own ``rng``; seeded so toy images are reproducible.
TOYMODEL_RNG = default_rng(0)

# Accepted gain-channel names.
# NOTE(review): presumably consumed by ``WaveformModel`` elsewhere in this
# module — confirm against its usage.
VALID_GAIN_CHANNEL = {"HIGH", "LOW", "ALL"}
def obtain_time_image(x, y, centroid_x, centroid_y, psi, time_gradient, time_intercept):
    """Create a pulse time image for a toymodel shower.

    Assumes the time development occurs only along the longitudinal (major)
    axis of the shower, and scales linearly with distance along the axis.

    Parameters
    ----------
    x : u.Quantity[length]
        X camera coordinate to evaluate the time at.
        Usually the array of pixel X positions.
        The unit of ``x`` is used for all other coordinate arguments.
    y : u.Quantity[length]
        Y camera coordinate to evaluate the time at.
        Usually the array of pixel Y positions
    centroid_x : u.Quantity[length]
        X camera coordinate for the centroid of the shower
    centroid_y : u.Quantity[length]
        Y camera coordinate for the centroid of the shower
    psi : convertible to `astropy.coordinates.Angle`
        rotation angle about the centroid (0=x-axis), e.g. ``40 * u.deg``
        or ``"40d"``
    time_gradient : u.Quantity[time / (unit of x)]
        Rate at which the time changes with distance along the shower axis,
        e.g. ns/m for coordinates in meters
    time_intercept : u.Quantity[time]
        Pulse time at the shower centroid

    Returns
    -------
    float or ndarray
        Pulse time in nanoseconds at (x, y), as plain numbers (no units)
    """
    # local import: only this function needs to parse generic angle inputs
    from astropy.coordinates import Angle

    unit = x.unit
    x = x.to_value(unit)
    y = y.to_value(unit)
    centroid_x = centroid_x.to_value(unit)
    centroid_y = centroid_y.to_value(unit)
    # Angle accepts angular Quantities as well as strings like "40d",
    # which is what the documented contract promises
    psi = Angle(psi).to_value(u.rad)
    time_gradient = time_gradient.to_value(u.ns / unit)
    time_intercept = time_intercept.to_value(u.ns)

    longitudinal, _ = camera_to_shower_coordinates(x, y, centroid_x, centroid_y, psi)
    return longitudinal * time_gradient + time_intercept
class ImageModel(metaclass=ABCMeta):
    """Abstract base class for toy shower image models.

    Subclasses implement :meth:`pdf`, a density in the camera plane; this
    base class turns it into expected and randomized pixel images by
    multiplying with the pixel area and the requested intensity.
    """

    @abstractmethod
    def pdf(self, x, y):
        """Probability density function."""

    def generate_image(self, camera, intensity=50, nsb_level_pe=20, rng=None):
        """Generate a randomized DL1 shower image.

        For the signal, poisson random numbers are drawn from
        the expected signal distribution for each pixel.
        For the background, for each pixel a poisson random number
        is drawn with mean ``nsb_level_pe``.

        Parameters
        ----------
        camera : `ctapipe.instrument.CameraGeometry`
            camera geometry object
        intensity : int
            Expected total number of photo electrons of the signal
        nsb_level_pe : float
            level of NSB/pedestal in photo-electrons
        rng : numpy.random.Generator or None
            Random generator to draw from; defaults to the module-level,
            fixed-seed ``TOYMODEL_RNG``.

        Returns
        -------
        image: array with length n_pixels containing the image
        signal: only the signal part of image
        noise: only the noise part of image
        """
        if rng is None:
            rng = TOYMODEL_RNG
        expected_signal = self.expected_signal(camera, intensity)
        signal = rng.poisson(expected_signal)
        noise = rng.poisson(nsb_level_pe, size=signal.shape)
        # subtract the sample mean of the noise as a simple pedestal estimate
        image = (signal + noise) - np.mean(noise)
        return image, signal, noise

    def expected_signal(self, camera, intensity):
        """Expected signal in each pixel for the given camera
        and total intensity.

        Parameters
        ----------
        camera: `ctapipe.instrument.CameraGeometry`
            camera geometry object
        intensity: int
            Total number of expected photo electrons

        Returns
        -------
        image: array with length n_pixels containing the image
        """
        pdf = self.pdf(camera.pix_x, camera.pix_y)
        # The models evaluate their density in ``self.unit`` coordinates, so
        # the pixel area must be expressed in ``self.unit**2`` as well;
        # taking the raw ``.value`` is only correct when both units agree.
        unit = getattr(self, "unit", None)
        if unit is None:
            pixel_area = camera.pix_area.value
        else:
            pixel_area = camera.pix_area.to_value(unit**2)
        return pdf * intensity * pixel_area
 
class Gaussian(ImageModel):
    """Elliptical two-dimensional Gaussian shower image model."""

    def __init__(self, x, y, length, width, psi):
        """Create 2D Gaussian model for a shower image in a camera.

        Parameters
        ----------
        x : u.Quantity[length]
            x coordinate of the shower centroid in camera coordinates.
            Its unit is used internally for all lengths.
        y : u.Quantity[length]
            y coordinate of the shower centroid in camera coordinates
        length: u.Quantity[length]
            standard deviation along the major axis of the shower
        width: u.Quantity[length]
            standard deviation along the minor axis of the shower
        psi : u.Quantity[angle]
            rotation angle about the centroid (0=x-axis); passed to
            ``linalg.rotation_matrix_2d``
        """
        self.unit = x.unit
        self.x = x
        self.y = y
        self.width = width
        self.length = length
        self.psi = psi
        aligned_covariance = np.array(
            [
                [self.length.to_value(self.unit) ** 2, 0],
                [0, self.width.to_value(self.unit) ** 2],
            ]
        )
        # rotate by psi angle: C' = R C R^T
        rotation = linalg.rotation_matrix_2d(self.psi)
        rotated_covariance = rotation @ aligned_covariance @ rotation.T
        self.dist = multivariate_normal(
            mean=u.Quantity([self.x, self.y]).to_value(self.unit),
            cov=rotated_covariance,
        )

    def pdf(self, x, y):
        """2d probability density for photo electrons in the camera plane.

        ``x`` and ``y`` may have any (matching) shape; the result has that
        shape. Stacking along the last axis gives the same points as the
        former ``column_stack`` for 1D input but also supports 2D grids.
        """
        points = np.stack([x.to_value(self.unit), y.to_value(self.unit)], axis=-1)
        return self.dist.pdf(points)
 
class SkewedGaussian(ImageModel):
    """A shower image that has a skewness along the major axis."""

    def __init__(self, x, y, length, width, psi, skewness):
        """Create 2D skewed Gaussian model for a shower image in a camera.

        Skewness is only applied along the main shower axis.
        See https://en.wikipedia.org/wiki/Skew_normal_distribution and
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skewnorm.html
        for details.

        Parameters
        ----------
        x : u.Quantity[length]
            x coordinate of the shower centroid (mean of the distribution)
            in camera coordinates. Its unit is used internally for all lengths.
        y : u.Quantity[length]
            y coordinate of the shower centroid in camera coordinates
        length: u.Quantity[length]
            standard deviation along the major axis of the shower
        width: u.Quantity[length]
            standard deviation along the minor axis of the shower
        psi : u.Quantity[angle]
            rotation angle about the centroid (0=x-axis)
        skewness : float
            skewness along the major axis. A skew-normal distribution can only
            reach ``|skewness| < ~0.9953``.

        Raises
        ------
        ValueError
            If ``skewness`` is outside the range reachable by a skew-normal
            distribution.
        """
        self.unit = x.unit
        self.x = x
        self.y = y
        self.width = width
        self.length = length
        self.psi = psi
        self.skewness = skewness
        a, loc, scale = self._moments_to_parameters()
        self.long_dist = skewnorm(a=a, loc=loc, scale=scale)
        self.trans_dist = norm(loc=0, scale=self.width.to_value(self.unit))
        # rotating by -psi maps camera offsets onto (longitudinal, transverse)
        self.rotation = linalg.rotation_matrix_2d(-self.psi)
        self.mu = u.Quantity([self.x, self.y]).to_value(self.unit)

    def _moments_to_parameters(self):
        """Return ``(a, loc, scale)`` of `scipy.stats.skewnorm` such that the
        distribution has zero mean, standard deviation ``self.length`` and
        skewness ``self.skewness``."""
        # method of moments, see
        # https://en.wikipedia.org/wiki/Skew_normal_distribution#Estimation
        skew23 = np.abs(self.skewness) ** (2 / 3)
        delta = np.sign(self.skewness) * np.sqrt(
            (np.pi / 2 * skew23) / (skew23 + (0.5 * (4 - np.pi)) ** (2 / 3))
        )
        # |delta| >= 1 happens for |skewness| >= ~0.9953; it would produce an
        # infinite or NaN shape parameter and silently yield NaN images.
        if np.abs(delta) >= 1:
            raise ValueError(
                "skewness must satisfy |skewness| < ~0.9953 for a skew-normal "
                f"distribution, got {self.skewness}"
            )
        a = delta / np.sqrt(1 - delta**2)
        scale = self.length.to_value(self.unit) / np.sqrt(1 - 2 * delta**2 / np.pi)
        # shift so that the mean (loc + scale * delta * sqrt(2/pi)) is zero
        loc = -scale * delta * np.sqrt(2 / np.pi)
        return a, loc, scale

    def pdf(self, x, y):
        """2d probability density for photo electrons in the camera plane.

        ``x`` and ``y`` may have any (matching) shape; scalars are treated as
        one-element arrays, as before.
        """
        points = np.stack(
            [
                np.atleast_1d(x.to_value(self.unit)),
                np.atleast_1d(y.to_value(self.unit)),
            ],
            axis=-1,
        )
        # rows of the -psi rotation project offsets onto the shower axes
        projected = (points - self.mu) @ self.rotation.T
        longitudinal = projected[..., 0]
        transverse = projected[..., 1]
        return self.trans_dist.pdf(transverse) * self.long_dist.pdf(longitudinal)
 
class RingGaussian(ImageModel):
    """A shower image consisting of a ring with gaussian radial profile.

    Simplified model for a muon ring, optionally with asymmetry based on the muon impact.

    Parameters
    ----------
    x : u.Quantity
        x-coordinate of the ring center
    y : u.Quantity
        y-coordinate of the ring center
    radius : u.Quantity
        ring radius
    sigma : u.Quantity
        std. dev. along the radial axis, i.e. the width of point spread function
    rho : float
        Ratio of the muon impact parameter to the mirror radius. A value of
        0 means a homogeneously illuminated ring, values >= 1 result in only
        partially illuminated rings.
    phi0 : u.Quantity[angle]
        Position angle of the muon impact on the mirror.
    """

    def __init__(
        self,
        x,
        y,
        radius,
        sigma,
        rho=0.0,
        phi0=0 * u.deg,
    ):
        self.unit = x.unit
        self.x = x
        self.y = y
        self.sigma = sigma
        self.radius = radius
        self.r_dist = norm(
            self.radius.to_value(self.unit), self.sigma.to_value(self.unit)
        )
        self.rho = rho
        self.phi0 = phi0.to(u.rad)
        if rho >= 1:
            # only |phi| < phi_max is illuminated; normalize over that arc
            self.phi_max = np.arcsin(1 / self.rho)
            integral, _ = quad(self._inner_term, 0, self.phi_max)
            self.norm = 1.0 / (2.0 * integral)
        else:
            # the sqrt term integrates to 4 * ellipe(rho**2) over [-pi, pi]
            # and the rho * cos(phi) term to zero
            self.norm = 0.25 / ellipe(self.rho**2)

    def pdf(self, x, y):
        """2d probability density for photo electrons in the camera plane.

        The radial and azimuthal profiles are densities in (r, phi); the area
        element in polar coordinates is ``r dr dphi``, so their product has to
        be divided by ``r`` to become a density per unit area. Without this
        the image integrates to the ring radius instead of 1. The (measure
        zero) ring center is assigned a density of 0.
        """
        dx = (x - self.x).to_value(self.unit)
        dy = (y - self.y).to_value(self.unit)
        r = np.sqrt(dx**2 + dy**2)
        phi = np.arctan2(dy, dx)
        density = self.r_dist.pdf(r) * self._pdf_phi(phi)
        return np.divide(
            density, r, out=np.zeros_like(density, dtype=float), where=r > 0
        )

    def _inner_term(self, phi):
        return np.sqrt(1 - self.rho**2 * np.sin(phi) ** 2)

    def _pdf_phi(self, phi):
        """Normalized azimuthal profile of the ring at angle(s) ``phi`` (rad)."""
        # Wrap back into [-pi, pi): the phi0 shift can move angles outside
        # that range, which would make the |phi| < phi_max mask below drop
        # part of the illuminated arc.
        phi = np.mod(phi + self.phi0.to_value(u.rad) + np.pi, 2 * np.pi) - np.pi
        if self.rho >= 1:
            mask = np.abs(phi) < self.phi_max
            pdf = np.zeros(phi.shape)
            pdf[mask] = self._inner_term(phi[mask])
            return self.norm * pdf
        else:
            return self.norm * (self._inner_term(phi) + self.rho * np.cos(phi))