"""Module containing classes related to event loading and preprocessing"""
from pathlib import Path
import astropy.units as u
import numpy as np
from astropy.coordinates import AltAz, SkyCoord
from astropy.table import Column, QTable, Table, vstack
try:
from pyirf.simulations import SimulatedEventsInfo
from pyirf.spectral import (
DIFFUSE_FLUX_UNIT,
POINT_SOURCE_FLUX_UNIT,
PowerLaw,
calculate_event_weights,
)
from pyirf.utils import calculate_source_fov_offset, calculate_theta
has_pyirf = True
except ModuleNotFoundError:
has_pyirf = False
from tables import NoSuchNodeError
from traitlets import default
from ..compat import COPY_IF_NEEDED
from ..containers import CoordinateFrameType
from ..coordinates import NominalFrame
from ..core import Component, QualityQuery
from ..core.traits import Bool, Dict, List, Tuple, Unicode
from .tableloader import TableLoader
# Public API of this module.
__all__ = [
    "DL2EventLoader",
    "DL2EventPreprocessor",
    "DL2EventQualityQuery",
]
class DL2EventQualityQuery(QualityQuery):
    """
    Event pre-selection quality criteria for IRF computation.

    Behaves exactly like `QualityQuery`, only the default list of criteria
    differs: a minimum of four telescopes used by the geometry reconstruction
    plus valid outputs of the classifier, geometry and energy reconstructors.
    """

    quality_criteria = List(
        Tuple(Unicode(), Unicode()),
        default_value=[
            ("multiplicity 4", "np.count_nonzero(HillasReconstructor_telescopes,axis=1) >= 4"),
            ("valid classifier", "RandomForestClassifier_is_valid"),
            ("valid geom reco", "HillasReconstructor_is_valid"),
            ("valid energy reco", "RandomForestRegressor_is_valid"),
        ],
        # reuse the parent's help text so the documentation stays in sync
        help=QualityQuery.quality_criteria.help,
    ).tag(config=True)
class DL2EventPreprocessor(Component):
    """
    Defines pre-selection cuts and the necessary renaming of columns.

    Owns the `DL2EventQualityQuery` used for pre-selection and the
    configuration that maps reconstructor/classifier/pointing input columns
    onto the names of ``output_table_schema``.
    """

    classes = [DL2EventQualityQuery]

    energy_reconstructor = Unicode(
        default_value="RandomForestRegressor",
        help="Prefix of the reco `_energy` column",
    ).tag(config=True)

    geometry_reconstructor = Unicode(
        default_value="HillasReconstructor",
        help="Prefix of the `_alt` and `_az` reco geometry columns",
    ).tag(config=True)

    gammaness_classifier = Unicode(
        default_value="RandomForestClassifier",
        help="Prefix of the classifier `_prediction` column",
    ).tag(config=True)

    apply_derived_columns = Bool(
        default_value=True, help="Whether to compute derived columns"
    ).tag(config=True)

    allow_unsupported_pointing_frames = Bool(
        default_value=False,
        help=(
            "If True, skip the checks that the pointing is supported. "
            "For the moment, only a fixed pointing in altaz is supported. "
            "Divergent pointing is also not supported."
        ),
    ).tag(config=True)

    columns_to_rename = Dict(
        key_trait=Unicode(),
        value_trait=Unicode(),
        help=(
            "Dictionary of columns to rename. "
            "Leave unset to apply default renaming. "
            "Set to an empty dictionary to disable renaming entirely. "
            "Set to a partial dictionary to override only some names."
        ),
    ).tag(config=True)

    output_table_schema = List(
        default_value=[
            Column(name="obs_id", dtype=np.uint64, description="Observation Block ID"),
            Column(name="event_id", dtype=np.uint64, description="Array event ID"),
            Column(name="true_energy", unit=u.TeV, description="Simulated energy"),
            Column(name="true_az", unit=u.deg, description="Simulated azimuth"),
            Column(name="true_alt", unit=u.deg, description="Simulated altitude"),
            Column(name="reco_energy", unit=u.TeV, description="Reconstructed energy"),
            Column(name="reco_az", unit=u.deg, description="Reconstructed azimuth"),
            Column(name="reco_alt", unit=u.deg, description="Reconstructed altitude"),
            Column(name="pointing_az", unit=u.deg, description="Pointing azimuth"),
            Column(name="pointing_alt", unit=u.deg, description="Pointing altitude"),
            Column(
                name="gh_score",
                dtype=np.float64,
                description="prediction of the classifier, defined between [0,1],"
                " where values close to 1 mean that the positive class"
                " (e.g. gamma in gamma-ray analysis) is more likely",
            ),
        ],
        help="Schema definition for output event QTable",
    ).tag(config=True)

    # maximum allowed spread of the pointing latitude within one table
    _MAX_POINTING_LAT_STD = 1e-3 * u.deg

    def __init__(self, config=None, parent=None, **kwargs):
        super().__init__(config=config, parent=parent, **kwargs)
        self.quality_query = DL2EventQualityQuery(parent=self)

    def _builtin_columns_to_rename(self) -> dict[str, str]:
        """Return the built-in input -> output column mapping for the configured prefixes."""
        return {
            f"{self.energy_reconstructor}_energy": "reco_energy",
            f"{self.geometry_reconstructor}_az": "reco_az",
            f"{self.geometry_reconstructor}_alt": "reco_alt",
            f"{self.gammaness_classifier}_prediction": "gh_score",
            "subarray_pointing_lat": "pointing_alt",
            "subarray_pointing_lon": "pointing_az",
        }

    @default("columns_to_rename")
    def _default_columns_to_rename(self):
        return self._builtin_columns_to_rename()

    def _effective_columns_to_rename(self) -> dict[str, str]:
        """
        Return the rename mapping actually applied.

        An explicitly empty ``columns_to_rename`` disables renaming. Otherwise
        the configured entries override the built-in ones that share either
        the same source or the same target column; all other built-in
        entries are kept (so a partial dictionary overrides only some names).
        """
        configured = dict(self.columns_to_rename)
        if not configured:
            return {}
        overridden_sources = set(configured)
        overridden_targets = set(configured.values())
        merged = {
            source: target
            for source, target in self._builtin_columns_to_rename().items()
            if source not in overridden_sources and target not in overridden_targets
        }
        merged.update(configured)
        return merged

    def _check_pointing(self, events) -> None:
        """
        Raise NotImplementedError if the pointing is not a fixed AltAz pointing.

        The latitude spread is compared as a Quantity so the check works for
        both `Table` and `QTable` input; a column without a unit is taken
        to be in degrees.
        """
        if len(events) == 0:
            # nothing to check; also avoids the std-of-empty RuntimeWarning
            return
        pointing_lat = u.Quantity(events["subarray_pointing_lat"], u.deg)
        if pointing_lat.std() > self._MAX_POINTING_LAT_STD:
            raise NotImplementedError(
                "No support for making irfs from varying pointings yet"
            )
        if np.any(
            events["subarray_pointing_frame"] != CoordinateFrameType.ALTAZ.value
        ):
            raise NotImplementedError(
                "At the moment only pointing in altaz is supported."
            )

    def normalise_column_names(self, events: QTable) -> QTable:
        """
        Rename column names according to configuration.

        Parameters
        ----------
        events : QTable
            Input event table.

        Returns
        -------
        QTable
            Table with selected and renamed columns. Schema columns that are
            not produced by renaming come first (in schema order), followed
            by the renamed columns.

        Raises
        ------
        NotImplementedError
            If pointing is not AltAz or varies too much (unless
            ``allow_unsupported_pointing_frames`` is set).
        ValueError
            If required columns are missing.
        """
        if not self.allow_unsupported_pointing_frames:
            self._check_pointing(events)

        columns_to_keep = [col.name for col in self.output_table_schema]
        rename_dict = self._effective_columns_to_rename()
        rename_from = list(rename_dict.keys())
        rename_to = list(rename_dict.values())

        # Preserve schema order: a set difference would make the column order
        # depend on string hashing (PYTHONHASHSEED).
        renamed_targets = set(rename_to)
        fixed_columns = [name for name in columns_to_keep if name not in renamed_targets]
        columns_to_read = fixed_columns + rename_from

        for col in columns_to_read:
            if col not in events.colnames:
                raise ValueError(
                    f"Input files must conform to the ctapipe DL2 data model. "
                    f"Required column {col} is missing."
                )

        events = QTable(events[columns_to_read], copy=COPY_IF_NEEDED)
        if rename_from:
            events.rename_columns(rename_from, rename_to)
        return events

    @staticmethod
    def _derived_columns_schema() -> list:
        """Schema of the extra columns added when ``apply_derived_columns`` is set."""
        return [
            Column(
                name="reco_fov_lat",
                unit=u.deg,
                description="Reconstructed FOV lat",
            ),
            Column(
                name="reco_fov_lon",
                unit=u.deg,
                description="Reconstructed FOV lon",
            ),
            Column(
                name="theta",
                unit=u.deg,
                description="Angular offset from source",
            ),
            Column(
                name="true_source_fov_offset",
                unit=u.deg,
                description="Simulated angular offset from pointing direction",
            ),
            Column(
                name="reco_source_fov_offset",
                unit=u.deg,
                description="Reconstructed angular offset from pointing direction",
            ),
            Column(
                name="weight",
                dtype=np.float64,
                description="Event weight",
            ),
        ]

    def make_empty_table(self) -> QTable:
        """
        Create an empty event table based on the configured output schema.

        If ``apply_derived_columns`` is True, the derived columns (FOV
        coordinates, theta, source offsets and weight) are appended.
        """
        # shallow copy so extending it does not modify the trait value
        schema = list(self.output_table_schema)
        if self.apply_derived_columns:
            schema.extend(self._derived_columns_schema())

        return QTable(
            names=[col.name for col in schema],
            dtype=[col.dtype for col in schema],
            units=[col.unit for col in schema]
            if any(col.unit for col in schema)
            else None,
            meta={},
        )
class DL2EventLoader(Component):
    """
    Component for loading events and simulation metadata, applying preselection and optional derived column logic.
    """

    classes = [DL2EventPreprocessor]

    # User-selectable event reading function and kwargs
    event_reader_function = Unicode(
        default_value="read_subarray_events_chunked",
        help=(
            "Function of TableLoader used to read event chunks. "
            "E.g., 'read_subarray_events_chunked' or 'read_telescope_events_chunked'."
        ),
    ).tag(config=True)

    event_reader_kwargs = Dict(
        default_value={},
        help=(
            "Extra keyword arguments passed to the event reading function, "
            "e.g., {'path': '/dl2/event/telescope/Reconstructor'}. "
            "Entries override the defaults dl2=True, simulated=True, "
            "observation_info=True."
        ),
    ).tag(config=True)

    def __init__(self, file: Path, target_spectrum: "Spectra", **kwargs):  # noqa: F821
        """
        Parameters
        ----------
        file : Path
            Input DL2 file.
        target_spectrum
            Key into ``ctapipe.irf.spectra.SPECTRA`` selecting the spectrum
            events are re-weighted to.
        """
        from ..irf.spectra import SPECTRA

        super().__init__(**kwargs)
        self.epp = DL2EventPreprocessor(parent=self)
        self.target_spectrum = SPECTRA[target_spectrum]
        self.file = file

    def load_preselected_events(
        self, chunk_size: int, obs_time: u.Quantity
    ) -> tuple[QTable, int, dict]:
        """
        Load and filter events from the file.

        Parameters
        ----------
        chunk_size : int
            Size of chunks to read from the file.
        obs_time : Quantity
            Observation time to scale weights.

        Returns
        -------
        table : QTable
            Filtered and processed event table.
        n_raw_events : int
            Number of events before selection.
        meta : dict
            Metadata dictionary with simulation info and input spectrum.
        """
        opts = dict(dl2=True, simulated=True, observation_info=True)
        # Merge instead of passing both ``**opts`` and ``**event_reader_kwargs``:
        # a shared key would otherwise raise TypeError. User entries win.
        reader_kwargs = {**opts, **self.event_reader_kwargs}

        with TableLoader(self.file, parent=self, **opts) as loader:
            table_template = self.epp.make_empty_table()
            sim_info, spectrum = self.get_simulation_information(loader, obs_time)
            meta = {"sim_info": sim_info, "spectrum": spectrum}

            # The empty template first fixes column order and dtypes for vstack.
            event_chunks = [table_template]
            n_raw_events = 0
            reader_func = getattr(loader, self.event_reader_function)
            for _, _, events in reader_func(chunk_size, **reader_kwargs):
                selected = events[self.epp.quality_query.get_table_mask(events)]
                selected = self.epp.normalise_column_names(selected)
                if self.epp.apply_derived_columns:
                    selected = self.make_derived_columns(selected)
                event_chunks.append(selected)
                n_raw_events += len(events)

            # Putting it last ensures the correct metadata is used
            event_chunks.append(table_template)
            table = vstack(event_chunks, join_type="exact", metadata_conflicts="silent")

        return table, n_raw_events, meta

    def make_derived_columns(self, events: QTable) -> QTable:
        """
        Add derived quantities (e.g., theta, FOV offsets) to the table.

        Parameters
        ----------
        events : QTable
            Table containing normalized events.

        Returns
        -------
        QTable
            Table with added derived columns. ``weight`` is set to 1.0; the
            actual weights are computed by `make_event_weights`.
        """
        events["theta"] = calculate_theta(
            events,
            assumed_source_az=events["true_az"],
            assumed_source_alt=events["true_alt"],
        )
        events["true_source_fov_offset"] = calculate_source_fov_offset(
            events, prefix="true"
        )
        events["reco_source_fov_offset"] = calculate_source_fov_offset(
            events, prefix="reco"
        )

        pointing = SkyCoord(
            alt=events["pointing_alt"], az=events["pointing_az"], frame=AltAz()
        )
        reco = SkyCoord(alt=events["reco_alt"], az=events["reco_az"], frame=AltAz())
        nominal = NominalFrame(origin=pointing)
        reco_nominal = reco.transform_to(nominal)
        events["reco_fov_lon"] = u.Quantity(-reco_nominal.fov_lon)  # minus for GADF
        events["reco_fov_lat"] = u.Quantity(reco_nominal.fov_lat)
        events["weight"] = 1.0  # defer calculation of proper weights to later
        return events

    def make_event_weights(
        self,
        events: QTable,
        spectrum: "PowerLaw",
        kind: str,
        fov_offset_bins: u.Quantity | None = None,
    ) -> QTable:
        """
        Compute event weights to match the target spectrum.

        Parameters
        ----------
        events : QTable
            Input events.
        spectrum : PowerLaw
            Spectrum from simulation.
        kind : str
            Type of events ("gammas", etc.).
        fov_offset_bins : Quantity, optional
            Offset bins for integrating the diffuse flux into point source bins.

        Returns
        -------
        QTable
            Table with updated weights. In the point-like-target /
            diffuse-simulation case, only events whose
            ``true_source_fov_offset`` falls inside ``fov_offset_bins`` are
            re-weighted; all other rows keep their previous ``weight``.

        Raises
        ------
        ValueError
            If ``fov_offset_bins`` is required but not provided.
        """
        if (
            kind == "gammas"
            and self.target_spectrum.normalization.unit.is_equivalent(
                POINT_SOURCE_FLUX_UNIT
            )
            and spectrum.normalization.unit.is_equivalent(DIFFUSE_FLUX_UNIT)
        ):
            if fov_offset_bins is None:
                raise ValueError(
                    "gamma_target_spectrum is point-like, but no fov offset bins "
                    "for the integration of the simulated diffuse spectrum were given."
                )

            true_offset = events["true_source_fov_offset"]
            true_energy = events["true_energy"]
            for low, high in zip(fov_offset_bins[:-1], fov_offset_bins[1:]):
                fov_mask = (true_offset >= low) & (true_offset < high)
                # index only the needed column instead of copying the whole table
                events["weight"][fov_mask] = calculate_event_weights(
                    true_energy[fov_mask],
                    target_spectrum=self.target_spectrum,
                    simulated_spectrum=spectrum.integrate_cone(low, high),
                )
        else:
            events["weight"] = calculate_event_weights(
                events["true_energy"],
                target_spectrum=self.target_spectrum,
                simulated_spectrum=spectrum,
            )

        return events