diff --git a/ctapipe/calib/camera/calibrator.py b/ctapipe/calib/camera/calibrator.py index a6eb81c565c..373301ce82b 100644 --- a/ctapipe/calib/camera/calibrator.py +++ b/ctapipe/calib/camera/calibrator.py @@ -13,8 +13,8 @@ from ctapipe.core import TelescopeComponent from ctapipe.core.traits import ( BoolTelescopeParameter, + ComponentName, TelescopeParameter, - create_class_enum_trait, ) from ctapipe.image.extractor import ImageExtractor from ctapipe.image.invalid_pixels import InvalidPixelHandler @@ -53,19 +53,17 @@ class CameraCalibrator(TelescopeComponent): The name of the ImageExtractor subclass to be used for image extraction """ - data_volume_reducer_type = create_class_enum_trait( + data_volume_reducer_type = ComponentName( DataVolumeReducer, default_value="NullDataVolumeReducer" ).tag(config=True) image_extractor_type = TelescopeParameter( - trait=create_class_enum_trait( - ImageExtractor, default_value="NeighborPeakWindowSum" - ), + trait=ComponentName(ImageExtractor, default_value="NeighborPeakWindowSum"), default_value="NeighborPeakWindowSum", help="Name of the ImageExtractor subclass to be used.", ).tag(config=True) - invalid_pixel_handler_type = create_class_enum_trait( + invalid_pixel_handler_type = ComponentName( InvalidPixelHandler, default_value="NeighborAverage", help="Name of the InvalidPixelHandler to use", diff --git a/ctapipe/core/tests/test_traits.py b/ctapipe/core/tests/test_traits.py index f3dd7bbf45b..1a593367092 100644 --- a/ctapipe/core/tests/test_traits.py +++ b/ctapipe/core/tests/test_traits.py @@ -1,9 +1,12 @@ import os import pathlib import tempfile +from abc import ABCMeta, abstractmethod from unittest import mock import pytest +from traitlets import CaselessStrEnum, HasTraits, Int + from ctapipe.core import Component, TelescopeComponent from ctapipe.core.traits import ( AstroTime, @@ -21,7 +24,6 @@ ) from ctapipe.image import ImageExtractor from ctapipe.utils.datasets import DEFAULT_URL, get_dataset_path -from traitlets import CaselessStrEnum, HasTraits, Int @pytest.fixture(scope="module") @@ -74,7 +76,7 @@ class C(Component): def test_path_exists(): - """ require existence of path """ + """require existence of path""" class C1(Component): thepath = Path(exists=False) @@ -120,7 +122,7 @@ class C1(Component): def test_path_directory_ok(): - """ test path is a directory """ + """test path is a directory""" class C(Component): thepath = Path(exists=True, directory_ok=False) @@ -139,7 +141,7 @@ class C(Component): def test_path_file_ok(): - """ check that the file is there and not a directory, etc""" + """check that the file is there and not a directory, etc""" class C(Component): thepath = Path(exists=True, file_ok=False) @@ -200,7 +202,7 @@ class C(Component): def test_enum_trait_default_is_right(): - """ check default value of enum trait """ + """check default value of enum trait""" from ctapipe.core.traits import create_class_enum_trait with pytest.raises(ValueError): @@ -208,7 +210,7 @@ def test_enum_trait_default_is_right(): def test_enum_trait(): - """ check that enum traits are constructable from a complex class """ + """check that enum traits are constructable from a complex class""" from ctapipe.core.traits import create_class_enum_trait trait = create_class_enum_trait( @@ -218,7 +220,7 @@ def test_enum_trait(): def test_enum_classes_with_traits(): - """ test that we can get a list of classes that have traits """ + """test that we can get a list of classes that have traits""" list_of_classes = classes_with_traits(ImageExtractor) assert list_of_classes # should not be empty @@ -248,15 +250,15 @@ class MyTool(Tool): def test_has_traits(): - """ test the has_traits func """ + """test the has_traits func""" class WithoutTraits(HasTraits): - """ a traits class that has no traits """ + """a traits class that has no traits""" pass class WithATrait(HasTraits): - """ a traits class that has a trait """ + """a traits class that has a trait""" my_trait = Int() @@ -290,7 +292,7 @@ def test_telescope_parameter_lookup(mock_subarray): def test_telescope_parameter_patterns(mock_subarray): - """ Test validation of TelescopeParameters""" + """Test validation of TelescopeParameters""" with pytest.raises(TypeError): TelescopeParameter(trait=int) @@ -583,3 +585,90 @@ class NoNone(Component): c = NoNone() with pytest.raises(TraitError): c.time = None + + +def test_component_name(): + from ctapipe.core.traits import ComponentName, ComponentNameList + + class Base(Component, metaclass=ABCMeta): + @abstractmethod + def stuff(self): + pass + + class Foo(Base): + def stuff(self): + pass + + class Baz(Component): + def stuff(self): + pass + + class MyComponent(Component): + base_name = ComponentName( + Base, + default_value="Foo", + help="A Base instance to do stuff", + ).tag(config=True) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.base = Base.from_name(self.base_name, parent=self) + self.base.stuff() + + class MyListComponent(Component): + base_names = ComponentNameList( + Base, + default_value=None, + allow_none=True, + ).tag(config=True) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.bases = [] + + if self.base_names is not None: + self.bases = [ + Base.from_name(name, parent=self) for name in self.base_names + ] + + for base in self.bases: + base.stuff() + + # this is here so we test that also classes defined after the traitlet + # is created work + class Bar(Base): + def stuff(self): + pass + + comp = MyComponent() + assert comp.base_name == "Foo" + + comp = MyComponent(base_name="Bar") + assert comp.base_name == "Bar" + + with pytest.raises(TraitError): + # Base is abstract + MyComponent(base_name="Base") + + with pytest.raises(TraitError): + # not a subclass of Base + MyComponent(base_name="Baz") + + with pytest.raises(TraitError): + # not a class at all + MyComponent(base_name="slakndklas") + + expected = "A Base instance to do stuff. Possible values: ['Foo', 'Bar']" + assert MyComponent.base_name.help == expected + + comp = MyListComponent() + assert comp.base_names is None + + comp = MyListComponent(base_names=["Foo", "Bar"]) + assert comp.base_names == ["Foo", "Bar"] + + with pytest.raises(TraitError): + MyListComponent(base_names=["Foo", "Baz"]) + + expected = "A list of Base subclass names. Possible values: ['Foo', 'Bar']" + assert MyListComponent.base_names.help == expected diff --git a/ctapipe/core/traits.py b/ctapipe/core/traits.py index 7821d897ec9..99b2ef7e443 100644 --- a/ctapipe/core/traits.py +++ b/ctapipe/core/traits.py @@ -14,7 +14,7 @@ from astropy.time import Time from traitlets import Undefined -from .component import non_abstract_children +from .component import Component, non_abstract_children __all__ = [ # Implemented here @@ -226,6 +226,69 @@ def create_class_enum_trait(base_class, default_value, help=None, allow_none=Fal ).tag(config=True) +class ComponentName(Unicode): + """A trait that is the name of a Component class""" + + def __init__(self, cls, **kwargs): + if not issubclass(cls, Component): + raise TypeError(f"cls must be a Component, got {cls}") + + self.cls = cls + super().__init__(**kwargs) + if "help" not in kwargs: + self.help = f"The name of a {cls.__name__} subclass" + + @property + def help(self): + children = list(self.cls.non_abstract_subclasses()) + return f"{self._help}. Possible values: {children}" + + @help.setter + def help(self, value): + self._help = value + + @property + def info_text(self): + return f"Any of {list(self.cls.non_abstract_subclasses())}" + + def validate(self, obj, value): + if self.allow_none and value is None: + return None + + if value in self.cls.non_abstract_subclasses(): + return value + + self.error(obj, value) + + +class ComponentNameList(List): + """A trait that is a list of Component classes""" + + def __init__(self, cls, **kwargs): + if not issubclass(cls, Component): + raise TypeError(f"cls must be a Component, got {cls}") + + self.cls = cls + trait = ComponentName(cls) + super().__init__(trait=trait, **kwargs) + + if "help" not in kwargs: + self.help = f"A list of {cls.__name__} subclass names" + + @property + def help(self): + children = list(self.cls.non_abstract_subclasses()) + return f"{self._help}. Possible values: {children}" + + @help.setter + def help(self, value): + self._help = value + + @property + def info_text(self): + return f"A list of {list(self.cls.non_abstract_subclasses())}" + + def classes_with_traits(base_class): """Returns a list of the base class plus its non-abstract children if they have traits""" @@ -437,6 +500,7 @@ def __init__(self, trait, default_value=Undefined, **kwargs): """ Create a new TelescopeParameter """ + self._help = "" if not isinstance(trait, TraitType): raise TypeError("trait must be a TraitType instance") @@ -445,8 +509,20 @@ def __init__(self, trait, default_value=Undefined, **kwargs): if default_value != Undefined: default_value = self.validate(self, default_value) + if "help" not in kwargs: + self.help = "A TelescopeParameter" + super().__init__(default_value=default_value, **kwargs) + @property + def help(self): + sep = "." if not self._help.endswith(".") else "" + return f"{self._help}{sep} {self._trait.help}" + + @help.setter + def help(self, value): + self._help = value + def from_string(self, s): val = super().from_string(s) # for strings, parsing fails and traitlets returns None diff --git a/ctapipe/image/extractor.py b/ctapipe/image/extractor.py index 4caa6dc056e..0e078582ef8 100644 --- a/ctapipe/image/extractor.py +++ b/ctapipe/image/extractor.py @@ -33,9 +33,9 @@ from ctapipe.core import TelescopeComponent from ctapipe.core.traits import ( BoolTelescopeParameter, + ComponentName, FloatTelescopeParameter, IntTelescopeParameter, - create_class_enum_trait, ) from .cleaning import tailcuts_clean @@ -877,7 +877,7 @@ class TwoPassWindowSum(ImageExtractor): default_value=True, help="Apply the integration window correction" ).tag(config=True) - invalid_pixel_handler_type = create_class_enum_trait( + invalid_pixel_handler_type = ComponentName( InvalidPixelHandler, default_value="NeighborAverage", help="Name of the InvalidPixelHandler to apply in the first pass.", diff --git a/ctapipe/image/image_processor.py b/ctapipe/image/image_processor.py index 0e7e1959601..c07ccc69125 100644 --- a/ctapipe/image/image_processor.py +++ b/ctapipe/image/image_processor.py @@ -17,7 +17,7 @@ TimingParametersContainer, ) from ..core import QualityQuery, TelescopeComponent -from ..core.traits import Bool, BoolTelescopeParameter, List, create_class_enum_trait +from ..core.traits import Bool, BoolTelescopeParameter, ComponentName, List from ..instrument import SubarrayDescription from .cleaning import ImageCleaner from .concentration import concentration_parameters @@ -64,9 +64,9 @@ class ImageProcessor(TelescopeComponent): Should be run after CameraCalibrator to produce all DL1 information. """ - image_cleaner_type = create_class_enum_trait( - base_class=ImageCleaner, default_value="TailcutsImageCleaner" - ) + image_cleaner_type = ComponentName( + ImageCleaner, default_value="TailcutsImageCleaner" + ).tag(config=True) use_telescope_frame = Bool( default_value=True, diff --git a/ctapipe/image/reducer.py b/ctapipe/image/reducer.py index a7c609d6b55..f7e63734d53 100644 --- a/ctapipe/image/reducer.py +++ b/ctapipe/image/reducer.py @@ -9,9 +9,9 @@ from ctapipe.core import TelescopeComponent from ctapipe.core.traits import ( BoolTelescopeParameter, + ComponentName, IntTelescopeParameter, TelescopeParameter, - create_class_enum_trait, ) from ctapipe.image import TailcutsImageCleaner from ctapipe.image.cleaning import dilate @@ -124,9 +124,7 @@ class TailCutsDataVolumeReducer(DataVolumeReducer): """ image_extractor_type = TelescopeParameter( - trait=create_class_enum_trait( - ImageExtractor, default_value="NeighborPeakWindowSum" - ), + trait=ComponentName(ImageExtractor, default_value="NeighborPeakWindowSum"), default_value="NeighborPeakWindowSum", help="Name of the ImageExtractor subclass to be used.", ).tag(config=True) diff --git a/ctapipe/io/simteleventsource.py b/ctapipe/io/simteleventsource.py index fa7e72189df..05f31ef79d0 100644 --- a/ctapipe/io/simteleventsource.py +++ b/ctapipe/io/simteleventsource.py @@ -48,7 +48,7 @@ from ..coordinates import CameraFrame, shower_impact_distance from ..core import Map from ..core.provenance import Provenance -from ..core.traits import Bool, Float, Undefined, UseEnum, create_class_enum_trait +from ..core.traits import Bool, ComponentName, Float, Undefined, UseEnum from ..instrument import ( CameraDescription, CameraGeometry, @@ -448,8 +448,8 @@ class SimTelEventSource(EventSource): ), ).tag(config=True) - gain_selector_type = create_class_enum_trait( - base_class=GainSelector, default_value="ThresholdGainSelector" + gain_selector_type = ComponentName( + GainSelector, default_value="ThresholdGainSelector" ).tag(config=True) calib_scale = Float( diff --git a/ctapipe/reco/shower_processor.py b/ctapipe/reco/shower_processor.py index 6fa6b4cd008..8c5c32d9996 100644 --- a/ctapipe/reco/shower_processor.py +++ b/ctapipe/reco/shower_processor.py @@ -9,7 +9,7 @@ """ from ..containers import ArrayEventContainer from ..core import Component -from ..core.traits import List, create_class_enum_trait +from ..core.traits import ComponentNameList from ..instrument import SubarrayDescription from . import Reconstructor @@ -24,13 +24,10 @@ class ShowerProcessor(Component): Input events must already contain dl1 parameters. """ - reconstructor_types = List( - create_class_enum_trait( - Reconstructor, - default_value="HillasReconstructor", - ), + reconstructor_types = ComponentNameList( + Reconstructor, default_value=["HillasReconstructor"], - help=f"The stereo geometry reconstructors to be used. Choices are: {list(Reconstructor.non_abstract_subclasses().keys())}", + help="The stereo reconstructors to be used", ).tag(config=True) def __init__(