# Source code for ctapipe.tools.train_disp_reconstructor

"""
Tool for training the DispReconstructor
"""

import astropy.units as u
import numpy as np

from ctapipe.core import Tool
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
from ctapipe.exceptions import InputMissing
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, DispReconstructor
from ctapipe.reco.preprocessing import horizontal_to_telescope

from .utils import read_training_events

__all__ = ["TrainDispReconstructor"]


class TrainDispReconstructor(Tool):
    """
    Tool to train a `~ctapipe.reco.DispReconstructor` on dl1b/dl2 data.

    The tool first performs a cross validation to give an initial estimate
    on the quality of the estimation and then finally trains two models
    (estimating ``norm(disp)`` and ``sign(disp)`` respectively) per
    telescope type on the full dataset.
    """

    name = "ctapipe-train-disp-reconstructor"
    # Inside a class body ``__doc__`` is the class docstring above, so the
    # tool description and the class documentation stay in sync.
    description = __doc__

    examples = """
    ctapipe-train-disp-reconstructor \\
        --config train_disp_reconstructor.yaml \\
        --input gamma.dl2.h5 \\
        --output disp_models.pkl
    """

    output_path = Path(
        directory_ok=False,
        help=(
            "Output path for the trained reconstructor."
            " At the moment, pickle is the only supported format."
        ),
    ).tag(config=True)

    n_events = IntTelescopeParameter(
        default_value=None,
        allow_none=True,
        help=(
            "Number of events for training the models."
            " If not given, all available events will be used."
        ),
    ).tag(config=True)

    chunk_size = Int(
        default_value=100000,
        allow_none=True,
        help="How many subarray events to load at once before training on n_events.",
    ).tag(config=True)

    random_seed = Int(
        default_value=0, help="Random seed for sampling training events."
    ).tag(config=True)

    n_jobs = Int(
        default_value=None,
        allow_none=True,
        help="Number of threads to use for the reconstruction. This overwrites the values in the config of each reconstructor.",
    ).tag(config=True)

    project_disp = Bool(
        default_value=False,
        help=(
            "If true, ``true_disp`` is the distance between shower cog and"
            " the true source position along the reconstructed main shower axis."
            " If false, ``true_disp`` is the distance between shower cog"
            " and the true source position."
        ),
    ).tag(config=True)

    aliases = {
        ("i", "input"): "TableLoader.input_url",
        ("o", "output"): "TrainDispReconstructor.output_path",
        "n-events": "TrainDispReconstructor.n_events",
        "n-jobs": "DispReconstructor.n_jobs",
        "cv-output": "CrossValidator.output_path",
    }

    classes = [TableLoader, DispReconstructor, CrossValidator]

    def setup(self):
        """
        Initialize components from config.

        Creates the table loader, the reconstructor, the cross validator and
        the random number generator used by `start`. Logs a critical message
        and exits with status 1 if no input file or no output path was
        configured.
        """
        try:
            self.loader = self.enter_context(TableLoader(parent=self))
        except InputMissing:
            self.log.critical(
                "Specifying TableLoader.input_url is required (via -i, --input or a config file)."
            )
            self.exit(1)

        if self.output_path is None:
            self.log.critical(
                "setting output_path is required (via -o, --output or a config file)."
            )
            self.exit(1)

        self.n_events.attach_subarray(self.loader.subarray)

        self.models = DispReconstructor(self.loader.subarray, parent=self)

        # The tool-level ``n_jobs`` option is documented as overriding the
        # reconstructor's own setting; apply it when it was given. Leaving it
        # at ``None`` keeps whatever the reconstructor was configured with.
        if self.n_jobs is not None:
            self.models.n_jobs = self.n_jobs

        self.cross_validate = self.enter_context(
            CrossValidator(
                parent=self, model_component=self.models, overwrite=self.overwrite
            )
        )
        self.rng = np.random.default_rng(self.random_seed)
        self.check_output(self.output_path)

    def start(self):
        """
        Train models per telescope type using a cross-validation.

        For every telescope type of the subarray the training events are
        read, the training target is computed with `_get_true_disp`, the
        cross validation is run and the models are finally fitted on the
        complete table.
        """
        types = self.loader.subarray.telescope_types
        self.log.info("Inputfile: %s", self.loader.input_url)
        self.log.info("Training models for %d types", len(types))

        for tel_type in types:
            self.log.info("Loading events for %s", tel_type)

            # Besides the model features, load the columns needed to compute
            # the target (pointing, true direction, Hillas cog and psi) and
            # the two true_* columns that are kept in the final table.
            feature_names = self.models.features + [
                "true_energy",
                "true_impact_distance",
                "subarray_pointing_lat",
                "subarray_pointing_lon",
                "true_alt",
                "true_az",
                "hillas_fov_lat",
                "hillas_fov_lon",
                "hillas_psi",
            ]
            table = read_training_events(
                loader=self.loader,
                chunk_size=self.chunk_size,
                telescope_type=tel_type,
                reconstructor=self.models,
                feature_names=feature_names,
                rng=self.rng,
                log=self.log,
                n_events=self.n_events.tel[tel_type],
            )
            table[self.models.target] = self._get_true_disp(table)

            # Drop the helper columns that were only needed for the target.
            table = table[
                self.models.features
                + [self.models.target, "true_energy", "true_impact_distance"]
            ]

            self.log.info("Train models on %s events", len(table))
            self.cross_validate(tel_type, table)

            self.log.info("Performing final fit for %s", tel_type)
            self.models.fit(tel_type, table)
            self.log.info("done")

    def _get_true_disp(self, table):
        """
        Compute the signed true disp used as training target.

        Parameters
        ----------
        table : table-like
            Must provide the columns ``true_alt``, ``true_az``,
            ``subarray_pointing_lat``, ``subarray_pointing_lon``,
            ``hillas_psi``, ``hillas_fov_lon`` and ``hillas_fov_lat``.
            The three ``hillas_*`` columns are accessed via ``.quantity``,
            i.e. they are expected to carry units.

        Returns
        -------
        Signed disp per row: the norm (see ``project_disp``) multiplied by
        the sign of the projection of the cog-to-source offset onto the
        shower axis.
        """
        fov_lon, fov_lat = horizontal_to_telescope(
            alt=table["true_alt"],
            az=table["true_az"],
            pointing_alt=table["subarray_pointing_lat"],
            pointing_az=table["subarray_pointing_lon"],
        )

        # numpy's trigonometric functions need radians
        psi = table["hillas_psi"].quantity.to_value(u.rad)
        cog_lon = table["hillas_fov_lon"].quantity
        cog_lat = table["hillas_fov_lat"].quantity

        # Offset from the shower centre of gravity to the true source position.
        delta_lon = fov_lon - cog_lon
        delta_lat = fov_lat - cog_lat

        # Projection of that offset onto the axis given by psi.
        true_disp = np.cos(psi) * delta_lon + np.sin(psi) * delta_lat
        # np.sign yields 0 for a vanishing projection, so such rows end up
        # with a target of exactly 0 regardless of the norm below.
        true_sign = np.sign(true_disp)

        if self.project_disp:
            true_norm = np.abs(true_disp)
        else:
            true_norm = np.sqrt(delta_lon**2 + delta_lat**2)

        return true_norm * true_sign

    def finish(self):
        """
        Write-out trained models and cross-validation results.
        """
        self.log.info("Writing output")
        # Reset n_jobs before writing -- presumably so the stored models do
        # not carry the thread count used for training (TODO confirm).
        self.models.n_jobs = None
        self.models.write(self.output_path, overwrite=self.overwrite)
        self.loader.close()
        self.cross_validate.close()
def main():
    """Command-line entry point: build the training tool and run it."""
    tool = TrainDispReconstructor()
    tool.run()


if __name__ == "__main__":
    main()