"""Tool to generate selections for IRFs production"""
import astropy.units as u
from astropy.table import vstack
from ..core import Provenance, Tool, traits
from ..core.traits import AstroQuantity, Integer, classes_with_traits
from ..io import DL2EventLoader
from ..irf import Spectra
from ..irf.optimize import CutOptimizerBase
__all__ = ["EventSelectionOptimizer"]
class EventSelectionOptimizer(Tool):
    "Tool to create optimized cuts for IRF generation"

    name = "ctapipe-optimize-event-selection"
    # The class docstring doubles as the tool's command-line description.
    description = __doc__
    examples = """
    ctapipe-optimize-event-selection \\
        --gamma-file gamma.dl2.h5 \\
        --proton-file proton.dl2.h5 \\
        --electron-file electron.dl2.h5 \\
        --output cuts.fits
    """

    gamma_file = traits.Path(
        default_value=None, directory_ok=False, help="Gamma input filename and path"
    ).tag(config=True)

    gamma_target_spectrum = traits.UseEnum(
        Spectra,
        default_value=Spectra.CRAB_HEGRA,
        help="Name of the spectrum used for weights of gamma events.",
    ).tag(config=True)

    proton_file = traits.Path(
        default_value=None,
        directory_ok=False,
        allow_none=True,
        help=(
            "Proton input filename and path. "
            "Not needed, if ``optimization_algorithm = 'PercentileCuts'``."
        ),
    ).tag(config=True)

    proton_target_spectrum = traits.UseEnum(
        Spectra,
        default_value=Spectra.IRFDOC_PROTON_SPECTRUM,
        help="Name of the spectrum used for weights of proton events.",
    ).tag(config=True)

    electron_file = traits.Path(
        default_value=None,
        directory_ok=False,
        allow_none=True,
        help=(
            "Electron input filename and path. "
            "Not needed, if ``optimization_algorithm = 'PercentileCuts'``."
        ),
    ).tag(config=True)

    electron_target_spectrum = traits.UseEnum(
        Spectra,
        default_value=Spectra.IRFDOC_ELECTRON_SPECTRUM,
        help="Name of the spectrum used for weights of electron events.",
    ).tag(config=True)

    chunk_size = Integer(
        default_value=100000,
        allow_none=True,
        help="How many subarray events to load at once when preselecting events.",
    ).tag(config=True)

    output_path = traits.Path(
        default_value="./Selection_Cuts.fits",
        allow_none=False,
        directory_ok=False,
        help="Output file storing optimization result",
    ).tag(config=True)

    obs_time = AstroQuantity(
        default_value=u.Quantity(50, u.hour),
        physical_type=u.physical.time,
        help=(
            "Observation time in the form ``<value> <unit>``."
            " This is used for flux normalization when calculating sensitivities."
        ),
    ).tag(config=True)

    optimization_algorithm = traits.ComponentName(
        CutOptimizerBase,
        default_value="PointSourceSensitivityOptimizer",
        help="The cut optimization algorithm to be used.",
    ).tag(config=True)

    aliases = {
        ("g", "gamma-file"): "EventSelectionOptimizer.gamma_file",
        ("p", "proton-file"): "EventSelectionOptimizer.proton_file",
        ("e", "electron-file"): "EventSelectionOptimizer.electron_file",
        ("o", "output"): "EventSelectionOptimizer.output_path",
        "chunk_size": "EventSelectionOptimizer.chunk_size",
    }

    classes = [DL2EventLoader] + classes_with_traits(CutOptimizerBase)

    def setup(self):
        """
        Initialize components from config.

        Instantiates the configured cut optimizer and one ``DL2EventLoader``
        per particle type. A gamma loader is always created. Proton (required)
        and electron (optional) loaders are only created when the optimizer
        reports ``needs_background``.

        Raises
        ------
        ValueError
            If the optimizer needs background events but ``proton_file`` is
            unset or does not exist.
        """
        self.optimizer = CutOptimizerBase.from_name(
            self.optimization_algorithm, parent=self
        )
        self.event_loaders = {
            "gammas": DL2EventLoader(
                parent=self,
                file=self.gamma_file,
                target_spectrum=self.gamma_target_spectrum,
            )
        }

        if not self.optimizer.needs_background:
            return

        if not self.proton_file or not self.proton_file.exists():
            raise ValueError(
                "Need a proton file for cut optimization "
                f"using {self.optimization_algorithm}."
            )
        self.event_loaders["protons"] = DL2EventLoader(
            parent=self,
            file=self.proton_file,
            target_spectrum=self.proton_target_spectrum,
        )

        # Electrons are an optional background component: fall back to a
        # proton-only background when no usable electron file is given.
        if self.electron_file and self.electron_file.exists():
            self.event_loaders["electrons"] = DL2EventLoader(
                parent=self,
                file=self.electron_file,
                target_spectrum=self.electron_target_spectrum,
            )
        else:
            self.log.warning("Optimizing cuts without electron file.")

    def start(self):
        """
        Load events and optimize g/h (and theta) cuts.

        For every loader created in ``setup``, preselected events are loaded
        (``chunk_size`` subarray events at a time). If the optimizer needs
        background, per-event weights are attached via
        ``loader.make_event_weights`` using the optimizer's background FoV
        offset range. Gammas form the signal sample; protons plus electrons
        (if loaded) form the background sample. The optimizer output is
        stored in ``self.result``.
        """
        reduced_events = {}
        for particle_type, loader in self.event_loaders.items():
            self.log.info("Loading %s from '%s'", particle_type, loader.file)
            events, count, meta = loader.load_preselected_events(
                self.chunk_size, self.obs_time
            )
            if self.optimizer.needs_background:
                events = loader.make_event_weights(
                    events,
                    meta["spectrum"],
                    particle_type,
                    (
                        self.optimizer.min_background_fov_offset,
                        self.optimizer.max_background_fov_offset,
                    ),
                )
            reduced_events[particle_type] = events
            reduced_events[f"{particle_type}_count"] = count
            if particle_type == "gammas":
                self.sim_info = meta["sim_info"]
                self.gamma_spectrum = meta["spectrum"]
            self.log.info("Using %8d %s", count, particle_type)

        self.signal_events = reduced_events["gammas"]

        if not self.optimizer.needs_background:
            self.log.debug("Loaded %d gammas", reduced_events["gammas_count"])
            self.log.debug("Keeping %d gammas", len(reduced_events["gammas"]))
            self.log.info("Optimizing cuts using %d signal", len(self.signal_events))
            self.background_events = None
        else:
            # Only stack tables that were actually loaded. Passing a
            # placeholder such as ``[]`` for missing electrons makes astropy's
            # ``vstack`` convert it into a table with a dummy column, which
            # leaks into the outer-joined background table and forces the
            # proton columns to become masked.
            background_tables = [reduced_events["protons"]]
            if "electrons" in reduced_events:
                background_tables.append(reduced_events["electrons"])

            self.log.debug(
                "Loaded %d gammas, %d protons, %d electrons",
                reduced_events["gammas_count"],
                reduced_events["protons_count"],
                reduced_events.get("electrons_count", 0),
            )
            self.log.debug(
                "Keeping %d gammas, %d protons, %d electrons",
                len(reduced_events["gammas"]),
                len(reduced_events["protons"]),
                len(reduced_events.get("electrons", ())),
            )
            self.background_events = vstack(background_tables)
            self.log.info(
                "Optimizing cuts using %d signal and %d background events",
                len(self.signal_events),
                len(self.background_events),
            )

        self.result = self.optimizer(
            events={"signal": self.signal_events, "background": self.background_events},
            # identical quality_query for all particle types
            quality_query=self.event_loaders["gammas"].epp.quality_query,
            clf_prefix=self.event_loaders["gammas"].epp.gammaness_classifier,
        )

    def finish(self):
        """
        Write optimized cuts to the output file.

        Registers ``output_path`` as a provenance output and writes
        ``self.result`` there, passing the tool's ``overwrite`` flag.
        """
        self.log.info("Writing results to %s", self.output_path)
        Provenance().add_output_file(self.output_path, role="Optimization Result")
        self.result.write(self.output_path, self.overwrite)
 
def main():
    """Run the ``ctapipe-optimize-event-selection`` tool from the command line."""
    tool = EventSelectionOptimizer()
    tool.run()


# When executed as a script, ``__name__`` is ``"__main__"`` (not ``"main"``).
if __name__ == "__main__":
    main()