diff --git a/ctapipe/analysis/camera/charge_resolution.py b/ctapipe/analysis/camera/charge_resolution.py index 832c161f641..f36b4a87d49 100644 --- a/ctapipe/analysis/camera/charge_resolution.py +++ b/ctapipe/analysis/camera/charge_resolution.py @@ -5,37 +5,38 @@ class ChargeResolutionCalculator: - def __init__(self, mc_true=True): - """ - Calculates the charge resolution with an efficient, low-memory, - interative approach, allowing the contribution of data/events - without reading the entire dataset into memory. - - Utilises Pandas DataFrames, and makes no assumptions on the order of - the data, and does not require the true charge to be integer (as may - be the case for lab measurements where an average illumination - is used). - - A list is filled with a dataframe for each contribution, and only - amalgamated into a single dataframe (reducing memory) once the memory - of the list becomes large (or at the end of the filling), - reducing the time required to produce the output. - - Parameters - ---------- - mc_true : bool - Indicate if the "true charge" values are from the sim_telarray - files, and therefore without poisson error. The poisson error will - therefore be included in the charge resolution calculation. + """ + Calculates the charge resolution with an efficient, low-memory, + interative approach, allowing the contribution of data/events + without reading the entire dataset into memory. + + Utilises Pandas DataFrames, and makes no assumptions on the order of + the data, and does not require the true charge to be integer (as may + be the case for lab measurements where an average illumination + is used). + + A list is filled with a dataframe for each contribution, and only + amalgamated into a single dataframe (reducing memory) once the memory + of the list becomes large (or at the end of the filling), + reducing the time required to produce the output. + + Parameters + ---------- + mc_true : bool + Indicate if the "true charge" values are from the sim_telarray + files, and therefore without poisson error. The poisson error will + therefore be included in the charge resolution calculation. + + Attributes + ---------- + self._mc_true : bool + self._df_list : list + self._df : pd.DataFrame + self._n_bytes : int + Monitors the number of bytes being held in memory + """ - Attributes - ---------- - self._mc_true : bool - self._df_list : list - self._df : pd.DataFrame - self._n_bytes : int - Monitors the number of bytes being held in memory - """ + def __init__(self, mc_true=True): self._mc_true = mc_true self._df_list = [] self._df = pd.DataFrame() diff --git a/ctapipe/core/__init__.py b/ctapipe/core/__init__.py index 67a3460ccc8..829aaf6205e 100644 --- a/ctapipe/core/__init__.py +++ b/ctapipe/core/__init__.py @@ -1,4 +1,7 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +Core functionality of ctapipe +""" from .component import Component, non_abstract_children from .container import Container, Field, Map diff --git a/ctapipe/core/component.py b/ctapipe/core/component.py index 795aea1c587..20df0090820 100644 --- a/ctapipe/core/component.py +++ b/ctapipe/core/component.py @@ -1,9 +1,11 @@ """ Class to handle configuration for algorithms """ from abc import ABCMeta -from logging import getLogger from inspect import isabstract -from traitlets.config import Configurable +from logging import getLogger + from traitlets import TraitError +from traitlets.config import Configurable + from ctapipe.core.plugins import detect_and_import_io_plugins @@ -31,10 +33,10 @@ def non_abstract_children(base): class AbstractConfigurableMeta(type(Configurable), ABCMeta): - ''' + """ Metaclass to be able to make Component abstract see: http://stackoverflow.com/a/7314847/3838691 - ''' + """ pass @@ -108,7 +110,8 @@ def __init__(self, config=None, parent=None, **kwargs): if not self.has_trait(key): raise TraitError(f"Traitlet does not exist: {key}") - # set up logging + # set up logging (for some reason the logger registered by LoggingConfig + # doesn't use a child logger of the parent by default) if self.parent: self.log = self.parent.log.getChild(self.__class__.__name__) else: @@ -130,11 +133,11 @@ def from_name(cls, name, config=None, parent=None): Used to set traitlet values. This argument is typically only specified when using this method from within a Tool. - tool : ctapipe.core.Tool + parent : ctapipe.core.Tool Tool executable that is calling this component. - Passes the correct logger to the component. + Passes the correct logger and configuration to the component. This argument is typically only specified when using this method - from within a Tool. + from within a Tool (config need not be passed if parent is used). Returns ------- @@ -149,3 +152,33 @@ def from_name(cls, name, config=None, parent=None): requested_subclass = subclasses[name] return requested_subclass(config=config, parent=parent) + + def get_current_config(self): + """ return the current configuration as a dict (e.g. the values + of all traits, even if they were not set during configuration) + """ + return { + self.__class__.__name__: { + k: v.get(self) for k, v in self.traits(config=True).items() + } + } + + def _repr_html_(self): + """ nice HTML rep, with blue for non-default values""" + traits = self.traits() + name = self.__class__.__name__ + lines = [ + f"{name}", + f"

{self.__class__.__doc__ or 'Undocumented!'}

", + "" + ] + for key, val in self.get_current_config()[name].items(): + thehelp = f'{traits[key].help} (default: {traits[key].default_value})' + lines.append(f"") + if val != traits[key].default_value: + lines.append(f"") + else: + lines.append(f"") + lines.append(f'') + lines.append("
{key}{val}{val}{thehelp}
") + return "\n".join(lines) diff --git a/ctapipe/core/container.py b/ctapipe/core/container.py index bb38f89e2df..8027b2b0bae 100644 --- a/ctapipe/core/container.py +++ b/ctapipe/core/container.py @@ -35,7 +35,7 @@ def __repr__(self): class ContainerMeta(type): - ''' + """ The MetaClass for the Containers It reserves __slots__ for every class variable, @@ -44,7 +44,8 @@ class ContainerMeta(type): This makes sure, that the metadata is immutable, and no new fields can be added to a container by accident. - ''' + """ + def __new__(cls, name, bases, dct): field_names = [ k for k, v in dct.items() @@ -100,7 +101,7 @@ class Container(metaclass=ContainerMeta): >>> >>> cont = MyContainer() >>> print(cont.x) - >>> # metdata will become header keywords in an output file: + >>> # metadata will become header keywords in an output file: >>> cont.meta['KEY'] = value `Field`s inside `Containers` can contain instances of other @@ -119,6 +120,7 @@ class Container(metaclass=ContainerMeta): `meta` attribute, which is a `dict` of keywords to values. """ + def __init__(self, **fields): self.meta = {} # __slots__ cannot be provided with defaults @@ -164,6 +166,8 @@ def as_dict(self, recursive=False, flatten=False, add_prefix=False): flatten: type return a flat dictionary, with any sub-field keys generated by appending the sub-Container name. + add_prefix: bool + include the container's prefix in the name of each item """ if not recursive: return dict(self.items(add_prefix=add_prefix)) diff --git a/ctapipe/core/logging.py b/ctapipe/core/logging.py index bf4145c92d7..2dd0e36f765 100644 --- a/ctapipe/core/logging.py +++ b/ctapipe/core/logging.py @@ -1,3 +1,5 @@ +""" helpers for better logging """ + import logging @@ -14,17 +16,19 @@ def format(self, record): reset_seq = "\033[0m" color_seq = "\033[1;%dm" colors = { - 'WARNING': yellow, 'INFO': green, 'DEBUG': blue, - 'CRITICAL': yellow, + 'WARNING': yellow, + 'CRITICAL': magenta, 'ERROR': red } levelname = record.levelname if levelname in colors: - levelname_color = color_seq % (30 + colors[levelname]) \ + levelname_color = ( + color_seq % (30 + colors[levelname]) + levelname + reset_seq + ) record.levelname = levelname_color if record.levelno >= self.highlevel_limit: diff --git a/ctapipe/core/plugins.py b/ctapipe/core/plugins.py index 1ca44eccbed..4c34f6f9004 100644 --- a/ctapipe/core/plugins.py +++ b/ctapipe/core/plugins.py @@ -1,9 +1,11 @@ +""" Functions for dealing with IO plugins """ + import importlib import pkgutil def detect_and_import_plugins(prefix): - ''' detect and import plugin modules with given prefix, ''' + """ detect and import plugin modules with given prefix, """ return { name: importlib.import_module(name) for finder, name, ispkg diff --git a/ctapipe/core/provenance.py b/ctapipe/core/provenance.py index d406a7cbd6e..8299a9d0e4a 100644 --- a/ctapipe/core/provenance.py +++ b/ctapipe/core/provenance.py @@ -12,12 +12,12 @@ import sys import uuid from contextlib import contextmanager -from os.path import abspath from importlib import import_module -from pkg_resources import get_distribution +from os.path import abspath import psutil from astropy.time import Time +from pkg_resources import get_distribution import ctapipe from .support import Singleton @@ -157,7 +157,13 @@ def provenance(self): def as_json(self, **kwargs): """ return all finished provenance as JSON. Kwargs for `json.dumps` may be included, e.g. `indent=4`""" - return json.dumps(self.provenance, **kwargs) + + def set_default(obj): + """ handle sets (not part of JSON) by converting to list""" + if isinstance(obj, set): + return list(obj) + + return json.dumps(self.provenance, default=set_default, **kwargs) @property def active_activity_names(self): diff --git a/ctapipe/core/tests/__init__.py b/ctapipe/core/tests/__init__.py index e69de29bb2d..3de94947755 100644 --- a/ctapipe/core/tests/__init__.py +++ b/ctapipe/core/tests/__init__.py @@ -0,0 +1,3 @@ +""" +tests of core functionality of ctapipe +""" \ No newline at end of file diff --git a/ctapipe/core/tests/test_component.py b/ctapipe/core/tests/test_component.py index db97e3254b2..51b296e7f25 100644 --- a/ctapipe/core/tests/test_component.py +++ b/ctapipe/core/tests/test_component.py @@ -1,11 +1,14 @@ from abc import abstractmethod, ABC + import pytest from traitlets import Float, TraitError from traitlets.config.loader import Config + from ctapipe.core import Component def test_non_abstract_children(): + """ check that we can find all constructable children """ from ctapipe.core import non_abstract_children class AbstractBase(ABC): @@ -36,17 +39,18 @@ class AbstractChild(AbstractBase): class ExampleComponent(Component): - description = "this is a test" + """ An Example Component, this is the help text""" param = Float(default_value=1.0, help="float parameter").tag(config=True) class ExampleSubclass1(ExampleComponent): - description = "this is a test" - + """ a subclass of ExampleComponent""" + pass class ExampleSubclass2(ExampleComponent): - description = "this is a test" + """ Another ExampleComponent """ + description = "A shorter description" param = Float(default_value=3.0, help="float parameter").tag(config=True) extra = Float(default_value=5.0, @@ -54,7 +58,7 @@ class ExampleSubclass2(ExampleComponent): def test_component_is_abstract(): - + """ check that we can make an abstract component """ class AbstractComponent(Component): @abstractmethod def test(self): @@ -79,7 +83,7 @@ def test_component_simple(): def test_component_kwarg_setting(): - + """ check that we can construct a component by setting traits via kwargs """ comp = ExampleComponent(param=3) assert comp.param == 3 @@ -93,11 +97,13 @@ def test_component_kwarg_setting(): def test_help(): + """ check that component help strings are generated correctly """ help_msg = ExampleComponent.class_get_help() assert "Default: 1.0" in help_msg def test_config(): + """ check that components can be constructed by config dict """ config = Config() config['ExampleComponent'] = Config() config['ExampleComponent']['param'] = 199. @@ -106,6 +112,7 @@ def test_config(): def test_config_baseclass(): + """ check that parent and subclass configuration works """ config = Config() config['ExampleComponent'] = Config() config['ExampleComponent']['param'] = 199. @@ -116,6 +123,7 @@ def test_config_baseclass(): def test_config_subclass1(): + """check sub-class config""" config = Config() config['ExampleSubclass1'] = Config() config['ExampleSubclass1']['param'] = 199. @@ -124,6 +132,7 @@ def test_config_subclass1(): def test_config_subclass2(): + """check another sub-class config""" config = Config() config['ExampleSubclass2'] = Config() config['ExampleSubclass2']['param'] = 199. @@ -132,6 +141,7 @@ def test_config_subclass2(): def test_config_sibling1(): + """ check sibling config """ config = Config() config['ExampleSubclass1'] = Config() config['ExampleSubclass1']['param'] = 199. @@ -142,6 +152,7 @@ def test_config_sibling1(): def test_config_sibling2(): + """ check sibling config """ config = Config() config['ExampleSubclass2'] = Config() config['ExampleSubclass2']['param'] = 199. @@ -152,6 +163,7 @@ def test_config_sibling2(): def test_config_baseclass_then_subclass(): + """ check base and subclass config """ config = Config() config['ExampleComponent'] = Config() config['ExampleComponent']['param'] = 199. @@ -162,6 +174,7 @@ def test_config_baseclass_then_subclass(): def test_config_subclass_then_baseclass(): + """ check subclass and base config """ config = Config() config['ExampleSubclass1'] = Config() config['ExampleSubclass1']['param'] = 229. @@ -172,6 +185,7 @@ def test_config_subclass_then_baseclass(): def test_config_override(): + """ check that we can override a trait set in the config """ config = Config() config['ExampleComponent'] = Config() config['ExampleComponent']['param'] = 199. @@ -180,6 +194,7 @@ def test_config_override(): def test_config_override_subclass(): + """ check that we can override a trait set in the config """ config = Config() config['ExampleComponent'] = Config() config['ExampleComponent']['param'] = 199. @@ -188,12 +203,14 @@ def test_config_override_subclass(): def test_extra(): + """ check that traits are settable """ comp = ExampleSubclass2(extra=229.) assert comp.has_trait('extra') is True assert comp.extra == 229. def test_extra_config(): + """ check setting trait via config """ config = Config() config['ExampleSubclass2'] = Config() config['ExampleSubclass2']['extra'] = 229. @@ -202,11 +219,16 @@ def test_extra_config(): def test_extra_missing(): + """ check that setting an incorrect trait raises an exception """ with pytest.raises(TraitError): ExampleSubclass1(extra=229.) def test_extra_config_missing(): + """ + check that setting an incorrect trait via config also raises + an exception + """ config = Config() config['ExampleSubclass1'] = Config() config['ExampleSubclass1']['extra'] = 199. @@ -218,21 +240,25 @@ def test_extra_config_missing(): def test_default(): + """ check default values work""" comp = ExampleComponent() assert comp.param == 1. def test_default_subclass(): + """ check default values work in subclasses""" comp = ExampleSubclass1() assert comp.param == 1. def test_default_subclass_override(): + """ check overrides work in subclasses""" comp = ExampleSubclass2() assert comp.param == 3. def test_change_default(): + """ check we can change a default value""" old_default = ExampleComponent.param.default_value ExampleComponent.param.default_value = 199. comp = ExampleComponent() @@ -241,6 +267,7 @@ def test_change_default(): def test_change_default_subclass(): + """ check we can change a default value in subclass """ old_default = ExampleComponent.param.default_value ExampleComponent.param.default_value = 199. comp = ExampleSubclass1() @@ -249,6 +276,7 @@ def test_change_default_subclass(): def test_change_default_subclass_override(): + """ check override default value """ old_default = ExampleComponent.param.default_value ExampleComponent.param.default_value = 199. comp = ExampleSubclass2() @@ -257,6 +285,7 @@ def test_change_default_subclass_override(): def test_help_changed_default(): + """ check that the help text is updated if the default is changed """ old_default = ExampleComponent.param.default_value ExampleComponent.param.default_value = 199. help_msg = ExampleComponent.class_get_help() @@ -265,6 +294,7 @@ def test_help_changed_default(): def test_from_name(): + """ Make sure one can construct a Component subclass by name""" subclass = ExampleComponent.from_name("ExampleSubclass1") assert isinstance(subclass, ExampleSubclass1) subclass = ExampleComponent.from_name("ExampleSubclass2") @@ -272,6 +302,23 @@ def test_from_name(): def test_from_name_config(): + """ make sure one can construct a Component subclass by name + config""" config = Config({'ExampleComponent': {'param': 229.}}) subclass = ExampleComponent.from_name("ExampleSubclass1", config=config) assert subclass.param == 229. + + +def test_component_current_config(): + """ make sure one can get the full current configuration""" + comp = ExampleComponent() + full_config = comp.get_current_config() + assert "ExampleComponent" in full_config + assert 'param' in full_config['ExampleComponent'] + assert full_config["ExampleComponent"]['param'] == 1.0 + + +def test_component_html_repr(): + """ check the HTML repr for Jupyter notebooks """ + comp = ExampleComponent() + html = comp._repr_html_() + assert len(html) > 10 diff --git a/ctapipe/core/tests/test_tool.py b/ctapipe/core/tests/test_tool.py index 9fa732e992f..e63ae3bec69 100644 --- a/ctapipe/core/tests/test_tool.py +++ b/ctapipe/core/tests/test_tool.py @@ -2,6 +2,7 @@ from traitlets import Float, TraitError from .. import Tool +from ..tool import export_tool_config_to_commented_yaml def test_tool_simple(): @@ -24,10 +25,56 @@ class MyTool(Tool): def test_tool_version(): - + """ check that the tool gets an automatic version string""" class MyTool(Tool): description = "test" userparam = Float(5.0, help="parameter").tag(config=True) tool = MyTool() assert tool.version_string != "" + + +def test_export_config_to_yaml(): + """ test that we can export a Tool's config to YAML""" + import yaml + from ctapipe.tools.camdemo import CameraDemo + + tool = CameraDemo() + tool.num_events = 2 + yaml_string = export_tool_config_to_commented_yaml(tool) + + # check round-trip back from yaml: + config_dict = yaml.load(yaml_string, Loader=yaml.SafeLoader) + + assert config_dict['CameraDemo']['num_events'] == 2 + + +def test_tool_html_rep(): + """ check that the HTML rep for Jupyter notebooks works""" + class MyTool(Tool): + description = "test" + userparam = Float(5.0, help="parameter").tag(config=True) + + class MyTool2(Tool): + """ A docstring description""" + userparam = Float(5.0, help="parameter").tag(config=True) + + tool = MyTool() + tool2 = MyTool2() + assert len(tool._repr_html_()) > 0 + assert len(tool2._repr_html_()) > 0 + + +def test_tool_current_config(): + """ Check that we can get the full instance configuration """ + class MyTool(Tool): + description = "test" + userparam = Float(5.0, help="parameter").tag(config=True) + + tool = MyTool() + conf1 = tool.get_current_config() + tool.userparam = -1.0 + conf2 = tool.get_current_config() + + assert conf1['MyTool']['userparam'] == 5.0 + assert conf2['MyTool']['userparam'] == -1.0 diff --git a/ctapipe/core/tests/test_traits.py b/ctapipe/core/tests/test_traits.py index 6761f78f536..8149494263f 100644 --- a/ctapipe/core/tests/test_traits.py +++ b/ctapipe/core/tests/test_traits.py @@ -1,65 +1,104 @@ -from ctapipe.core import Component -from ctapipe.core.traits import Path, TraitError - -from pytest import raises import tempfile +import pytest +from traitlets import CaselessStrEnum, HasTraits, Int + +from ctapipe.core import Component +from ctapipe.core.traits import ( + Path, + TraitError, + classes_with_traits, + enum_trait, + has_traits, +) +from ctapipe.image import ImageExtractor -def test_path_exists(): +def test_path_exists(): class C1(Component): - p = Path(exists=False) + thepath = Path(exists=False) c1 = C1() - c1.p = 'test' + c1.thepath = "test" with tempfile.NamedTemporaryFile() as f: - with raises(TraitError): - c1.p = f.name + with pytest.raises(TraitError): + c1.thepath = f.name class C2(Component): - p = Path(exists=True) + thepath = Path(exists=True) c2 = C2() with tempfile.TemporaryDirectory() as d: - c2.p = d + c2.thepath = d with tempfile.NamedTemporaryFile() as f: - c2.p = f.name + c2.thepath = f.name def test_path_directory_ok(): - class C(Component): - p = Path(exists=True, directory_ok=False) + thepath = Path(exists=True, directory_ok=False) c = C() - with raises(TraitError): - c.p = 'lknasdlakndlandslknalkndslakndslkan' + with pytest.raises(TraitError): + c.thepath = "lknasdlakndlandslknalkndslakndslkan" with tempfile.TemporaryDirectory() as d: - with raises(TraitError): - c.p = d + with pytest.raises(TraitError): + c.thepath = d with tempfile.NamedTemporaryFile() as f: - c.p = f.name + c.thepath = f.name def test_path_file_ok(): - class C(Component): - p = Path(exists=True, file_ok=False) + thepath = Path(exists=True, file_ok=False) c = C() - with raises(TraitError): - c.p = 'lknasdlakndlandslknalkndslakndslkan' + with pytest.raises(TraitError): + c.thepath = "lknasdlakndlandslknalkndslakndslkan" with tempfile.TemporaryDirectory() as d: - c.p = d + c.thepath = d with tempfile.NamedTemporaryFile() as f: - with raises(TraitError): - c.p = f.name + with pytest.raises(TraitError): + c.thepath = f.name + + +def test_enum_trait_default_is_right(): + """ check default value of enum trait """ + with pytest.raises(ValueError): + enum_trait(ImageExtractor, default="name_of_default_choice") + + +def test_enum_trait(): + """ check that enum traits are constructable from a complex class """ + trait = enum_trait(ImageExtractor, default="NeighborPeakWindowSum") + assert isinstance(trait, CaselessStrEnum) + + +def test_enum_classes_with_traits(): + """ test that we can get a list of classes that have traits """ + list_of_classes = classes_with_traits(ImageExtractor) + assert list_of_classes # should not be empty + + +def test_has_traits(): + """ test the has_traits func """ + + class WithoutTraits(HasTraits): + """ a traits class that has no traits """ + pass + + class WithATrait(HasTraits): + """ a traits class that has a trait """ + my_trait = Int() + + assert not has_traits(WithoutTraits) + assert has_traits(WithATrait) diff --git a/ctapipe/core/tool.py b/ctapipe/core/tool.py index 25c8e8ba23a..0ec15dcec24 100644 --- a/ctapipe/core/tool.py +++ b/ctapipe/core/tool.py @@ -1,14 +1,14 @@ +""" Classes to handle configurable command-line user interfaces """ import logging +import textwrap from abc import abstractmethod from traitlets import Unicode -from traitlets.config import Application +from traitlets.config import Application, Configurable from ctapipe import __version__ as version -from .logging import ColoredFormatter from . import Provenance - -logging.basicConfig(level=logging.WARNING) +from .logging import ColoredFormatter class ToolConfigurationError(Exception): @@ -101,8 +101,12 @@ def main(): """ config_file = Unicode('', help=("name of a configuration file with " - "parameters to load in addition to " - "command-line parameters")).tag(config=True) + "parameters to load in addition to " + "command-line parameters")).tag(config=True) + log_format = Unicode( + '%(levelname)s [%(name)s] (%(module)s/%(funcName)s): %(message)s', + help='The Logging format template' + ).tag(config=True) _log_formatter_cls = ColoredFormatter @@ -113,10 +117,9 @@ def __init__(self, **kwargs): self.aliases['config'] = 'Tool.config_file' super().__init__(**kwargs) - self.log_format = ('%(levelname)8s [%(name)s] ' - '(%(module)s/%(funcName)s): %(message)s') self.log_level = logging.INFO self.is_setup = False + self._registered_components = [] def initialize(self, argv=None): """ handle config and any other low-level setup """ @@ -126,6 +129,35 @@ def initialize(self, argv=None): self.load_config_file(self.config_file) self.log.info(f"ctapipe version {self.version_string}") + def add_component(self, component_instance): + """ + constructs and adds a component to the list of registered components, + so that later we can ask for the current configuration of all instances, + e.g. in`get_full_config()`. All sub-components of a tool should be + constructed using this function, in order to ensure the configuration is + properly traced. + + Parameters + ---------- + component_instance: Component + constructed instance of a component + + Returns + ------- + Component: + the same component instance that was passed in, so that the call + can be chained. + + Example + ------- + .. code-block:: python3 + + self.mycomp = self.add_component(MyComponent(parent=self)) + + """ + self._registered_components.append(component_instance) + return component_instance + @abstractmethod def setup(self): """set up the tool (override in subclass). Here the user should @@ -134,14 +166,14 @@ def setup(self): @abstractmethod def start(self): - """main body of tool (override in subclass). This is automatially + """main body of tool (override in subclass). This is automatically called after `initialize()` when the `run()` is called. """ pass @abstractmethod def finish(self): - """finish up (override in subclass). This is called automatially + """finish up (override in subclass). This is called automatically after `start()` when `run()` is called.""" self.log.info("Goodbye") @@ -159,11 +191,11 @@ def run(self, argv=None): try: self.initialize(argv) self.log.info(f"Starting: {self.name}") - self.log.debug(f"CONFIG: {self.config}") Provenance().start_activity(self.name) - Provenance().add_config(self.config) self.setup() self.is_setup = True + self.log.info(f"CONFIG: {self.get_current_config()}") + Provenance().add_config(self.get_current_config()) self.start() self.finish() self.log.info(f"Finished: {self.name}") @@ -191,3 +223,123 @@ def run(self, argv=None): def version_string(self): """ a formatted version string with version, release, and git hash""" return f"{version}" + + def get_current_config(self): + """ return the current configuration as a dict (e.g. the values + of all traits, even if they were not set during configuration) + """ + conf = { + self.__class__.__name__: { + k: v.get(self) for k, v in self.traits(config=True).items() + } + } + for component in self._registered_components: + conf.update(component.get_current_config()) + + return conf + + def _repr_html_(self): + """ nice HTML rep, with blue for non-default values""" + traits = self.traits() + name = self.__class__.__name__ + lines = [ + f"{name}", + f"

{self.__class__.__doc__ or self.description}

", + "", + ] + for key, val in self.get_current_config()[name].items(): + default = traits[key].default_value + thehelp = f'{traits[key].help} (default: {default})' + lines.append(f"") + if val != default: + lines.append(f"") + else: + lines.append(f"") + lines.append(f'') + lines.append("
{key}{val}{val}{thehelp}
") + lines.append("

Components:") + lines.append(", ".join([x.__name__ for x in self.classes])) + lines.append("

") + + return "\n".join(lines) + + +def export_tool_config_to_commented_yaml(tool_instance: Tool, classes=None): + """ + Turn the config of a single Component into a commented YAML string. + + This is a hacked version of + traitlets.config.Configurable._class_config_section() changed to + output a YAML file with defaults *and* current values filled in. + + Parameters + ---------- + tool_instance: Tool + a constructed Tool instance + classes: list, optional + The list of other classes in the config file. + Used to reduce redundant information. + """ + + tool = tool_instance.__class__ + config = tool_instance.get_current_config()[tool_instance.__class__.__name__] + + def commented(text, indent_level=2, width=70): + """return a commented, wrapped block.""" + return textwrap.fill( + text, + width=width, + initial_indent=" " * indent_level + "# ", + subsequent_indent=" " * indent_level + "# ", + ) + + # section header + breaker = '#' + '-' * 78 + parent_classes = ', '.join( + p.__name__ for p in tool.__bases__ + if issubclass(p, Configurable) + ) + + section_header = f"# {tool.__name__}({parent_classes}) configuration" + + lines = [breaker, section_header] + # get the description trait + desc = tool.class_traits().get('description') + if desc: + desc = desc.default_value + if not desc: + # no description from trait, use __doc__ + desc = getattr(tool, '__doc__', '') + if desc: + lines.append(commented(desc, indent_level=0)) + lines.append(breaker) + lines.append(f'{tool.__name__}:') + + for name, trait in sorted(tool.class_traits(config=True).items()): + default_repr = trait.default_value_repr() + current_repr = config.get(name, "") + if isinstance(current_repr, str): + current_repr = f'"{current_repr}"' + + if classes: + defining_class = tool._defining_class(trait, classes) + else: + defining_class = tool + if defining_class is tool: + # cls owns the trait, show full help + if trait.help: + lines.append(commented(trait.help)) + if 'Enum' in type(trait).__name__: + # include Enum choices + lines.append(commented(f'Choices: {trait.info()}')) + lines.append(commented(f'Default: {default_repr}')) + else: + # Trait appears multiple times and isn't defined here. + # Truncate help to first line + "See also Original.trait" + if trait.help: + lines.append(commented(trait.help.split('\n', 1)[0])) + lines.append( + f' # See also: {defining_class.__name__}.{name}') + lines.append(f' {name}: {current_repr}') + lines.append('') + return '\n'.join(lines) diff --git a/ctapipe/core/traits.py b/ctapipe/core/traits.py index 854a12f634e..2eb422ef486 100644 --- a/ctapipe/core/traits.py +++ b/ctapipe/core/traits.py @@ -1,17 +1,50 @@ -from traitlets import (Int, Integer, Float, Unicode, Enum, Long, List, - Bool, CRegExp, Dict, TraitError, observe, - CaselessStrEnum, TraitType) -from traitlets.config import boolean_flag as flag import os -__all__ = ['Path', 'Int', 'Integer', 'Float', 'Unicode', 'Enum', 'Long', 'List', - 'Bool', 'CRegExp', 'Dict', 'flag', 'TraitError', 'observe', - 'CaselessStrEnum'] +from traitlets import ( + Bool, + CaselessStrEnum, + CRegExp, + Dict, + Enum, + Float, + Int, + Integer, + List, + Long, + TraitError, + TraitType, + Unicode, + observe, +) +from traitlets.config import boolean_flag as flag + +from .component import non_abstract_children + +__all__ = [ + "Path", + "Int", + "Integer", + "Float", + "Unicode", + "Enum", + "Long", + "List", + "Bool", + "CRegExp", + "Dict", + "flag", + "TraitError", + "observe", + "CaselessStrEnum", + "enum_trait", + "classes_with_traits", + "has_traits", +] class Path(TraitType): def __init__(self, exists=None, directory_ok=True, file_ok=True): - ''' + """ A path Trait for input/output files. Parameters @@ -23,7 +56,7 @@ def __init__(self, exists=None, directory_ok=True, file_ok=True): If False, path must not be a directory file_ok: boolean If False, path must not be a file - ''' + """ super().__init__() self.exists = exists self.directory_ok = directory_ok @@ -35,20 +68,59 @@ def validate(self, obj, value): value = os.path.abspath(value) if self.exists is not None: if os.path.exists(value) != self.exists: - raise TraitError('Path "{}" {} exist'.format( - value, - 'does not' if self.exists else 'must' - )) - if os.path.exists(value): - if os.path.isdir(value) and not self.directory_ok: raise TraitError( - f'Path "{value}" must not be a directory' + 'Path "{}" {} exist'.format( + value, "does not" if self.exists else "must" + ) ) + if os.path.exists(value): + if os.path.isdir(value) and not self.directory_ok: + raise TraitError(f'Path "{value}" must not be a directory') if os.path.isfile(value) and not self.file_ok: - raise TraitError( - f'Path "{value}" must not be a file' - ) + raise TraitError(f'Path "{value}" must not be a file') return value return self.error(obj, value) + + +def enum_trait(base_class, default, help_str=None): + """create a configurable CaselessStrEnum traitlet from baseclass + + the enumeration should contain all names of non_abstract_children() + of said baseclass and the default choice should be given by + `base_class._default` name. + + default must be specified and must be the name of one child-class + """ + if help_str is None: + help_str = "{} to use.".format(base_class.__name__) + + choices = [cls.__name__ for cls in non_abstract_children(base_class)] + if default not in choices: + raise ValueError( + "{default} is not in choices: {choices}".format( + default=default, choices=choices + ) + ) + + return CaselessStrEnum(choices, default, allow_none=True, help=help_str).tag( + config=True + ) + + +def classes_with_traits(base_class): + """ Returns a list of the base class plus its non-abstract children + if they have traits """ + all_classes = [base_class] + non_abstract_children(base_class) + return [cls for cls in all_classes if has_traits(cls)] + + +def has_traits(cls, ignore=("config", "parent")): + """True if cls has any traits apart from the usual ones + + all our components have at least 'config' and 'parent' as traitlets + this is inherited from `traitlets.config.Configurable` so we ignore them + here. + """ + return bool(set(cls.class_trait_names()) - set(ignore)) diff --git a/ctapipe/io/eventsource.py b/ctapipe/io/eventsource.py index a058d82d8ef..552e94d4531 100644 --- a/ctapipe/io/eventsource.py +++ b/ctapipe/io/eventsource.py @@ -3,10 +3,12 @@ """ from abc import abstractmethod from os.path import exists + from traitlets import Unicode, Int, Set, TraitError +from traitlets.config.loader import LazyConfigValue + from ctapipe.core import Component, non_abstract_children from ctapipe.core import Provenance -from traitlets.config.loader import LazyConfigValue from ctapipe.core.plugins import detect_and_import_io_plugins __all__ = [ diff --git a/ctapipe/io/simteleventsource.py b/ctapipe/io/simteleventsource.py index 4416bda606b..c85c3803a31 100644 --- a/ctapipe/io/simteleventsource.py +++ b/ctapipe/io/simteleventsource.py @@ -73,7 +73,7 @@ def __init__(self, config=None, parent=None, **kwargs): # so we explicitly pass None in that case self.file_ = SimTelFile( self.input_url, - allowed_telescopes=self.allowed_tels if self.allowed_tels else None, + allowed_telescopes=set(self.allowed_tels) if self.allowed_tels else None, skip_calibration=self.skip_calibration_events, zcat=not self.back_seekable, ) diff --git a/ctapipe/tools/bokeh/file_viewer.py b/ctapipe/tools/bokeh/file_viewer.py index 6b63e68813f..2b92ac4b242 100644 --- a/ctapipe/tools/bokeh/file_viewer.py +++ b/ctapipe/tools/bokeh/file_viewer.py @@ -1,18 +1,19 @@ import os -from bokeh.layouts import widgetbox, layout -from bokeh.models import Select, TextInput, PreText, Button -from bokeh.server.server import Server + from bokeh.document.document import jinja2 +from bokeh.layouts import layout, widgetbox +from bokeh.models import Button, PreText, Select, TextInput +from bokeh.server.server import Server from bokeh.themes import Theme -from traitlets import Dict, List, Int, Bool +from traitlets import Bool, Dict, Int, List + from ctapipe.calib import CameraCalibrator -from ctapipe.core import Tool +from ctapipe.core import Tool, traits from ctapipe.image.extractor import ImageExtractor from ctapipe.io import EventSource from ctapipe.io.eventseeker import EventSeeker from ctapipe.plotting.bokeh_event_viewer import BokehEventViewer from ctapipe.utils import get_dataset_path -import ctapipe.utils.tools as tool_utils class BokehFileViewer(Tool): @@ -27,7 +28,7 @@ class BokehFileViewer(Tool): default_url = get_dataset_path("gamma_test_large.simtel.gz") EventSource.input_url.default_value = default_url - extractor_product = tool_utils.enum_trait( + extractor_product = traits.enum_trait( ImageExtractor, default='NeighborPeakWindowSum' ) @@ -43,7 +44,7 @@ class BokehFileViewer(Tool): classes = List( [ EventSource, - ] + tool_utils.classes_with_traits(ImageExtractor) + ] + traits.classes_with_traits(ImageExtractor) ) def __init__(self, **kwargs): diff --git a/ctapipe/tools/display_dl1.py b/ctapipe/tools/display_dl1.py index 7dd22d9e5a0..d9b1c9eb430 100644 --- a/ctapipe/tools/display_dl1.py +++ b/ctapipe/tools/display_dl1.py @@ -1,31 +1,32 @@ """ Calibrate dl0 data to dl1, and plot the photoelectron images. """ -from matplotlib import pyplot as plt, colors +from matplotlib import colors +from matplotlib import pyplot as plt from matplotlib.backends.backend_pdf import PdfPages -from traitlets import Dict, List, Int, Bool, Unicode +from traitlets import Bool, Dict, Int, List, Unicode from ctapipe.calib import CameraCalibrator -from ctapipe.visualization import CameraDisplay -from ctapipe.core import Tool, Component -from ctapipe.utils import get_dataset_path +from ctapipe.core import Component, Tool +from ctapipe.core import traits from ctapipe.image.extractor import ImageExtractor from ctapipe.io import EventSource -import ctapipe.utils.tools as tool_utils +from ctapipe.utils import get_dataset_path +from ctapipe.visualization import CameraDisplay class ImagePlotter(Component): + """ Plotter for camera images """ + display = Bool( - True, - help='Display the photoelectron images on-screen as they ' - 'are produced.' + True, help="Display the photoelectron images on-screen as they are produced." ).tag(config=True) output_path = Unicode( None, allow_none=True, - help='Output path for the pdf containing all the ' - 'images. Set to None for no saved ' - 'output.' + help="Output path for the pdf containing all the " + "images. Set to None for no saved " + "output.", ).tag(config=True) def __init__(self, config=None, parent=None, **kwargs): @@ -85,17 +86,20 @@ def plot(self, event, telid): tmaxmin = event.dl0.tel[telid].waveform.shape[2] t_chargemax = pulse_time[image.argmax()] cmap_time = colors.LinearSegmentedColormap.from_list( - 'cmap_t', - [(0 / tmaxmin, 'darkgreen'), - (0.6 * t_chargemax / tmaxmin, 'green'), - (t_chargemax / tmaxmin, 'yellow'), - (1.4 * t_chargemax / tmaxmin, 'blue'), (1, 'darkblue')] + "cmap_t", + [ + (0 / tmaxmin, "darkgreen"), + (0.6 * t_chargemax / tmaxmin, "green"), + (t_chargemax / tmaxmin, "yellow"), + (1.4 * t_chargemax / tmaxmin, "blue"), + (1, "darkblue"), + ], ) self.c_pulse_time.pixels.set_cmap(cmap_time) if not self.cb_intensity: self.c_intensity.add_colorbar( - ax=self.ax_intensity, label='Intensity (p.e.)' + ax=self.ax_intensity, label="Intensity (p.e.)" ) self.cb_intensity = self.c_intensity.colorbar else: @@ -103,7 +107,7 @@ def plot(self, event, telid): self.c_intensity.update(True) if not self.cb_pulse_time: self.c_pulse_time.add_colorbar( - ax=self.ax_pulse_time, label='Pulse Time (ns)' + ax=self.ax_pulse_time, label="Pulse Time (ns)" ) self.cb_pulse_time = self.c_pulse_time.colorbar else: @@ -115,8 +119,9 @@ def plot(self, event, telid): self.c_pulse_time.image = pulse_time self.fig.suptitle( - "Event_index={} Event_id={} Telescope={}" - .format(event.count, event.r0.event_id, telid) + "Event_index={} Event_id={} Telescope={}".format( + event.count, event.r0.event_id, telid + ) ) if self.display: @@ -137,41 +142,32 @@ class DisplayDL1Calib(Tool): telescope = Int( None, allow_none=True, - help='Telescope to view. Set to None to display all ' - 'telescopes.' + help="Telescope to view. Set to None to display all telescopes.", ).tag(config=True) - extractor_product = tool_utils.enum_trait( - ImageExtractor, - default='NeighborPeakWindowSum' + extractor_product = traits.enum_trait( + ImageExtractor, default="NeighborPeakWindowSum" ) aliases = Dict( dict( - max_events='EventSource.max_events', - extractor='DisplayDL1Calib.extractor_product', - T='DisplayDL1Calib.telescope', - O='ImagePlotter.output_path' + input="EventSource.input_url", + max_events="EventSource.max_events", + extractor="DisplayDL1Calib.extractor_product", + T="DisplayDL1Calib.telescope", + O="ImagePlotter.output_path", ) ) flags = Dict( dict( D=( - { - 'ImagePlotter': { - 'display': True - } - }, - "Display the photoelectron images on-screen as they " - "are produced." + {"ImagePlotter": {"display": True}}, + "Display the photo-electron images on-screen as they are produced.", ) ) ) classes = List( - [ - EventSource, - ImagePlotter - ] + tool_utils.classes_with_traits(ImageExtractor) + [EventSource, ImagePlotter] + traits.classes_with_traits(ImageExtractor) ) def __init__(self, **kwargs): @@ -181,14 +177,14 @@ def __init__(self, **kwargs): self.plotter = None def setup(self): - self.eventsource = EventSource.from_url( - get_dataset_path("gamma_test_large.simtel.gz"), - parent=self, + self.eventsource = self.add_component( + EventSource.from_url( + get_dataset_path("gamma_test_large.simtel.gz"), parent=self + ) ) - self.calibrator = CameraCalibrator(parent=self) - - self.plotter = ImagePlotter(parent=self) + self.calibrator = self.add_component(CameraCalibrator(parent=self)) + self.plotter = self.add_component(ImagePlotter(parent=self)) def start(self): for event in self.eventsource: diff --git a/ctapipe/tools/display_events_single_tel.py b/ctapipe/tools/display_events_single_tel.py index 88327546c4c..7fde80cb196 100755 --- a/ctapipe/tools/display_events_single_tel.py +++ b/ctapipe/tools/display_events_single_tel.py @@ -73,10 +73,14 @@ def __init__(self, **kwargs): def setup(self): print('TOLLES INFILE', self.infile) - self.event_source = EventSource.from_url(self.infile, parent=self) + self.event_source = self.add_component( + EventSource.from_url(self.infile, parent=self) + ) self.event_source.allowed_tels = {self.tel, } - self.calibrator = CameraCalibrator(parent=self) + self.calibrator = self.add_component( + CameraCalibrator(parent=self) + ) self.log.info(f'SELECTING EVENTS FROM TELESCOPE {self.tel}') diff --git a/ctapipe/tools/display_integrator.py b/ctapipe/tools/display_integrator.py index 6843e4fedda..74114d48cfd 100644 --- a/ctapipe/tools/display_integrator.py +++ b/ctapipe/tools/display_integrator.py @@ -7,12 +7,12 @@ from matplotlib import pyplot as plt from traitlets import Dict, List, Int, Bool, Enum -import ctapipe.utils.tools as tool_utils +from ctapipe.core import traits from ctapipe.calib import CameraCalibrator from ctapipe.core import Tool from ctapipe.image.extractor import ImageExtractor -from ctapipe.io.eventseeker import EventSeeker from ctapipe.io import EventSource +from ctapipe.io.eventseeker import EventSeeker from ctapipe.visualization import CameraDisplay @@ -38,7 +38,7 @@ def plot(event, telid, chan, extractor_name): ax_max_nei = {} ax_min_nei = {} fig_waveforms = plt.figure(figsize=(18, 9)) - fig_waveforms.subplots_adjust(hspace=.5) + fig_waveforms.subplots_adjust(hspace=0.5) fig_camera = plt.figure(figsize=(15, 12)) ax_max_pix = fig_waveforms.add_subplot(4, 2, 1) @@ -60,8 +60,8 @@ def plot(event, telid, chan, extractor_name): ax_max_pix.set_xlabel("Time (ns)") ax_max_pix.set_ylabel("DL0 Samples (ADC)") ax_max_pix.set_title( - f'(Max) Pixel: {max_pix}, True: {t_pe[max_pix]}, ' - f'Measured = {dl1[max_pix]:.3f}' + f"(Max) Pixel: {max_pix}, True: {t_pe[max_pix]}, " + f"Measured = {dl1[max_pix]:.3f}" ) max_ylim = ax_max_pix.get_ylim() for i, ax in ax_max_nei.items(): @@ -71,8 +71,9 @@ def plot(event, telid, chan, extractor_name): ax.set_xlabel("Time (ns)") ax.set_ylabel("DL0 Samples (ADC)") ax.set_title( - "(Max Nei) Pixel: {}, True: {}, Measured = {:.3f}" - .format(pix, t_pe[pix], dl1[pix]) + "(Max Nei) Pixel: {}, True: {}, Measured = {:.3f}".format( + pix, t_pe[pix], dl1[pix] + ) ) ax.set_ylim(max_ylim) @@ -81,8 +82,8 @@ def plot(event, telid, chan, extractor_name): ax_min_pix.set_xlabel("Time (ns)") ax_min_pix.set_ylabel("DL0 Samples (ADC)") ax_min_pix.set_title( - f'(Min) Pixel: {min_pix}, True: {t_pe[min_pix]}, ' - f'Measured = {dl1[min_pix]:.3f}' + f"(Min) Pixel: {min_pix}, True: {t_pe[min_pix]}, " + f"Measured = {dl1[min_pix]:.3f}" ) ax_min_pix.set_ylim(max_ylim) for i, ax in ax_min_nei.items(): @@ -92,8 +93,8 @@ def plot(event, telid, chan, extractor_name): ax.set_xlabel("Time (ns)") ax.set_ylabel("DL0 Samples (ADC)") ax.set_title( - f'(Min Nei) Pixel: {pix}, True: {t_pe[pix]}, ' - f'Measured = {dl1[pix]:.3f}' + f"(Min Nei) Pixel: {pix}, True: {t_pe[pix]}, " + f"Measured = {dl1[pix]:.3f}" ) ax.set_ylim(max_ylim) @@ -109,22 +110,22 @@ def plot(event, telid, chan, extractor_name): ax_img_nei.annotate( f"Pixel: {max_pix}", xy=(geom.pix_x.value[max_pix], geom.pix_y.value[max_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.98), - textcoords='axes fraction', - arrowprops=dict(facecolor='red', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="red", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) ax_img_nei.annotate( f"Pixel: {min_pix}", xy=(geom.pix_x.value[min_pix], geom.pix_y.value[min_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.94), - textcoords='axes fraction', - arrowprops=dict(facecolor='orange', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="orange", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) camera = CameraDisplay(geom, ax=ax_img_max) camera.image = dl0[:, max_time] @@ -133,22 +134,22 @@ def plot(event, telid, chan, extractor_name): ax_img_max.annotate( f"Pixel: {max_pix}", xy=(geom.pix_x.value[max_pix], geom.pix_y.value[max_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.98), - textcoords='axes fraction', - arrowprops=dict(facecolor='red', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="red", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) ax_img_max.annotate( f"Pixel: {min_pix}", xy=(geom.pix_x.value[min_pix], geom.pix_y.value[min_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.94), - textcoords='axes fraction', - arrowprops=dict(facecolor='orange', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="orange", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) camera = CameraDisplay(geom, ax=ax_img_true) @@ -158,22 +159,22 @@ def plot(event, telid, chan, extractor_name): ax_img_true.annotate( f"Pixel: {max_pix}", xy=(geom.pix_x.value[max_pix], geom.pix_y.value[max_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.98), - textcoords='axes fraction', - arrowprops=dict(facecolor='red', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="red", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) ax_img_true.annotate( f"Pixel: {min_pix}", xy=(geom.pix_x.value[min_pix], geom.pix_y.value[min_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.94), - textcoords='axes fraction', - arrowprops=dict(facecolor='orange', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="orange", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) camera = CameraDisplay(geom, ax=ax_img_cal) @@ -183,22 +184,22 @@ def plot(event, telid, chan, extractor_name): ax_img_cal.annotate( f"Pixel: {max_pix}", xy=(geom.pix_x.value[max_pix], geom.pix_y.value[max_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.98), - textcoords='axes fraction', - arrowprops=dict(facecolor='red', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="red", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) ax_img_cal.annotate( f"Pixel: {min_pix}", xy=(geom.pix_x.value[min_pix], geom.pix_y.value[min_pix]), - xycoords='data', + xycoords="data", xytext=(0.05, 0.94), - textcoords='axes fraction', - arrowprops=dict(facecolor='orange', width=2, alpha=0.4), - horizontalalignment='left', - verticalalignment='top' + textcoords="axes fraction", + arrowprops=dict(facecolor="orange", width=2, alpha=0.4), + horizontalalignment="left", + verticalalignment="top", ) fig_waveforms.suptitle(f"Integrator = {extractor_name}") @@ -211,51 +212,42 @@ class DisplayIntegrator(Tool): name = "ctapipe-display-integration" description = __doc__ - event_index = Int(0, help='Event index to view.').tag(config=True) + event_index = Int(0, help="Event index to view.").tag(config=True) use_event_id = Bool( False, - help='event_index will obtain an event using event_id instead of ' - 'index.' + help="event_index will obtain an event using event_id instead of index.", ).tag(config=True) telescope = Int( None, allow_none=True, - help='Telescope to view. Set to None to display the first' - 'telescope with data.' + help="Telescope to view. Set to None to display the first" + "telescope with data.", ).tag(config=True) - channel = Enum([0, 1], 0, help='Channel to view').tag(config=True) + channel = Enum([0, 1], 0, help="Channel to view").tag(config=True) - extractor_product = tool_utils.enum_trait( - ImageExtractor, - default='NeighborPeakWindowSum' + extractor_product = traits.enum_trait( + ImageExtractor, default="NeighborPeakWindowSum" ) aliases = Dict( dict( - f='EventSource.input_url', - max_events='EventSource.max_events', - extractor='DisplayIntegrator.extractor_product', - E='DisplayIntegrator.event_index', - T='DisplayIntegrator.telescope', - C='DisplayIntegrator.channel', + f="EventSource.input_url", + max_events="EventSource.max_events", + extractor="DisplayIntegrator.extractor_product", + E="DisplayIntegrator.event_index", + T="DisplayIntegrator.telescope", + C="DisplayIntegrator.channel", ) ) flags = Dict( dict( id=( - { - 'DisplayDL1Calib': { - 'use_event_index': True - } - }, 'event_index will obtain an event using ' - 'event_id instead of index.') + {"DisplayDL1Calib": {"use_event_index": True}}, + "event_index will obtain an event using event_id instead of index.", + ) ) ) - classes = List( - [ - EventSource, - ] + tool_utils.classes_with_traits(ImageExtractor) - ) + classes = List([EventSource] + traits.classes_with_traits(ImageExtractor)) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -268,15 +260,13 @@ def __init__(self, **kwargs): def setup(self): self.log_format = "%(levelname)s: %(message)s [%(name)s.%(funcName)s]" - event_source = EventSource.from_config(parent=self) - self.eventseeker = EventSeeker(event_source, parent=self) - self.extractor = ImageExtractor.from_name( - self.extractor_product, - parent=self, + event_source = self.add_component(EventSource.from_config(parent=self)) + self.eventseeker = self.add_component(EventSeeker(event_source, parent=self)) + self.extractor = self.add_component( + ImageExtractor.from_name(self.extractor_product, parent=self) ) - self.calibrator = CameraCalibrator( - parent=self, - image_extractor=self.extractor, + self.calibrate = self.add_component( + CameraCalibrator(parent=self, image_extractor=self.extractor) ) def start(self): @@ -286,7 +276,7 @@ def start(self): event = self.eventseeker[event_num] # Calibrate - self.calibrator(event) + self.calibrate(event) # Select telescope tels = list(event.r0.tels_with_data) diff --git a/ctapipe/tools/extract_charge_resolution.py b/ctapipe/tools/extract_charge_resolution.py index 08a92e5cff4..17b94026137 100644 --- a/ctapipe/tools/extract_charge_resolution.py +++ b/ctapipe/tools/extract_charge_resolution.py @@ -4,54 +4,51 @@ """ import os + import numpy as np import pandas as pd from tqdm import tqdm +from traitlets import Dict, Int, List, Unicode -from traitlets import Dict, List, Int, Unicode - -import ctapipe.utils.tools as tool_utils - -from ctapipe.analysis.camera.charge_resolution import \ - ChargeResolutionCalculator +from ctapipe.analysis.camera.charge_resolution import ChargeResolutionCalculator from ctapipe.calib import CameraCalibrator -from ctapipe.core import Tool, Provenance +from ctapipe.core import Provenance, Tool, traits from ctapipe.image.extractor import ImageExtractor - from ctapipe.io.simteleventsource import SimTelEventSource class ChargeResolutionGenerator(Tool): name = "ChargeResolutionGenerator" - description = ("Calculate the Charge Resolution from a sim_telarray " - "simulation and store within a HDF5 file.") + description = ( + "Calculate the Charge Resolution from a sim_telarray " + "simulation and store within a HDF5 file." + ) - telescopes = List(Int, None, allow_none=True, - help='Telescopes to include from the event file. ' - 'Default = All telescopes').tag(config=True) + telescopes = List( + Int, + None, + allow_none=True, + help="Telescopes to include from the event file. Default = All telescopes", + ).tag(config=True) output_path = Unicode( - 'charge_resolution.h5', - help='Path to store the output HDF5 file' + "charge_resolution.h5", help="Path to store the output HDF5 file" ).tag(config=True) - extractor_product = tool_utils.enum_trait( - ImageExtractor, - default='NeighborPeakWindowSum' + extractor_product = traits.enum_trait( + ImageExtractor, default="NeighborPeakWindowSum" ) - aliases = Dict(dict( - f='SimTelEventSource.input_url', - max_events='SimTelEventSource.max_events', - T='SimTelEventSource.allowed_tels', - extractor='ChargeResolutionGenerator.extractor_product', - O='ChargeResolutionGenerator.output_path', - )) - - classes = List( - [ - SimTelEventSource, - ] + tool_utils.classes_with_traits(ImageExtractor) + aliases = Dict( + dict( + f="SimTelEventSource.input_url", + max_events="SimTelEventSource.max_events", + T="SimTelEventSource.allowed_tels", + extractor="ChargeResolutionGenerator.extractor_product", + O="ChargeResolutionGenerator.output_path", + ) ) + classes = List([SimTelEventSource] + traits.classes_with_traits(ImageExtractor)) + def __init__(self, **kwargs): super().__init__(**kwargs) self.eventsource = None @@ -61,16 +58,14 @@ def __init__(self, **kwargs): def setup(self): self.log_format = "%(levelname)s: %(message)s [%(name)s.%(funcName)s]" - self.eventsource = SimTelEventSource(parent=self) + self.eventsource = self.add_component(SimTelEventSource(parent=self)) - extractor = ImageExtractor.from_name( - self.extractor_product, - parent=self + extractor = self.add_component( + ImageExtractor.from_name(self.extractor_product, parent=self) ) - self.calibrator = CameraCalibrator( - parent=self, - image_extractor=extractor, + self.calibrator = self.add_component( + CameraCalibrator(parent=self, image_extractor=extractor) ) self.calculator = ChargeResolutionCalculator() @@ -86,9 +81,7 @@ def start(self): if np.all(pe == 0): raise KeyError except KeyError: - self.log.exception( - 'Source does not contain true charge!' - ) + self.log.exception("Source does not contain true charge!") raise for mc, dl1 in zip(event.mc.tel.values(), event.dl1.tel.values()): @@ -105,12 +98,11 @@ def finish(self): self.log.info(f"Creating directory: {output_directory}") os.makedirs(output_directory) - with pd.HDFStore(self.output_path, 'w') as store: - store['charge_resolution_pixel'] = df_p - store['charge_resolution_camera'] = df_c + with pd.HDFStore(self.output_path, "w") as store: + store["charge_resolution_pixel"] = df_p + store["charge_resolution_camera"] = df_c - self.log.info("Created charge resolution file: {}" - .format(self.output_path)) + self.log.info("Created charge resolution file: {}".format(self.output_path)) Provenance().add_output_file(self.output_path) @@ -119,5 +111,5 @@ def main(): exe.run() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/ctapipe/tools/muon_reconstruction.py b/ctapipe/tools/muon_reconstruction.py index a6fb68f09b4..131b3484316 100644 --- a/ctapipe/tools/muon_reconstruction.py +++ b/ctapipe/tools/muon_reconstruction.py @@ -17,7 +17,7 @@ from ctapipe.core import traits as t from ctapipe.image.muon.muon_diagnostic_plots import plot_muon_event from ctapipe.image.muon.muon_reco_functions import analyze_muon_event -from ctapipe.io import EventSource, event_source +from ctapipe.io import EventSource from ctapipe.io import HDF5TableWriter warnings.filterwarnings("ignore") # Supresses iminuit warnings @@ -40,11 +40,10 @@ class MuonDisplayerTool(Tool): name = 'ctapipe-reconstruct-muons' description = t.Unicode(__doc__) - events = t.Unicode("", - help="input event data file").tag(config=True) - - outfile = t.Unicode("muons.hdf5", help='HDF5 output file name').tag( - config=True) + outfile = t.Unicode( + "muons.hdf5", + help='HDF5 output file name' + ).tag(config=True) display = t.Bool( help='display the camera events', default=False @@ -55,7 +54,7 @@ class MuonDisplayerTool(Tool): ]) aliases = t.Dict({ - 'input': 'MuonDisplayerTool.events', + 'input': 'EventSource.input_url', 'outfile': 'MuonDisplayerTool.outfile', 'display': 'MuonDisplayerTool.display', 'max_events': 'EventSource.max_events', @@ -63,12 +62,14 @@ class MuonDisplayerTool(Tool): }) def setup(self): - if self.events == '': + self.source: EventSource = self.add_component( + EventSource.from_config(parent=self) + ) + if self.source.input_url == '': raise ToolConfigurationError("please specify --input ") - self.log.debug("input: %s", self.events) - self.source = event_source(self.events) - self.calib = CameraCalibrator(parent=self) - self.writer = HDF5TableWriter(self.outfile, "muons") + self.calib = self.add_component(CameraCalibrator(parent=self)) + self.writer = self.add_component(HDF5TableWriter(self.outfile, "muons")) + def start(self): diff --git a/ctapipe/tools/plot_charge_resolution.py b/ctapipe/tools/plot_charge_resolution.py index 03fa53864bb..3913de54ea1 100644 --- a/ctapipe/tools/plot_charge_resolution.py +++ b/ctapipe/tools/plot_charge_resolution.py @@ -3,6 +3,7 @@ """ import numpy as np from traitlets import Dict, List, Unicode + from ctapipe.core import Tool from ctapipe.plotting.charge_resolution import ChargeResolutionPlotter @@ -28,12 +29,11 @@ class ChargeResolutionViewer(Tool): def __init__(self, **kwargs): super().__init__(**kwargs) - self.calculator = None self.plotter = None def setup(self): self.log_format = "%(levelname)s: %(message)s [%(name)s.%(funcName)s]" - self.plotter = ChargeResolutionPlotter(parent=self) + self.plotter = self.add_component(ChargeResolutionPlotter(parent=self)) def start(self): for fp in self.input_files: diff --git a/ctapipe/tools/tests/test_tools.py b/ctapipe/tools/tests/test_tools.py index 4266098342c..ad572122372 100644 --- a/ctapipe/tools/tests/test_tools.py +++ b/ctapipe/tools/tests/test_tools.py @@ -1,8 +1,9 @@ import os -import sys -import pytest import shlex +import sys + import matplotlib as mpl +import pytest from ctapipe.utils import get_dataset_path @@ -14,7 +15,7 @@ def test_muon_reconstruction(tmpdir): tool = MuonDisplayerTool() tool.run( argv=shlex.split( - f'--events={GAMMA_TEST_LARGE} ' + f'--input={GAMMA_TEST_LARGE} ' '--max_events=2 ' ) ) diff --git a/ctapipe/utils/tests/test_tools.py b/ctapipe/utils/tests/test_tools.py deleted file mode 100644 index da1be9d80df..00000000000 --- a/ctapipe/utils/tests/test_tools.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest -import traitlets -from traitlets import HasTraits -from traitlets import Int -# using this class as test input -from ctapipe.image.extractor import ImageExtractor - - -def test_enum_trait_default_is_right(): - # function under test - from ctapipe.utils.tools import enum_trait - - with pytest.raises(ValueError): - enum_trait(ImageExtractor, default='name_of_default_choice') - - -def test_enum_trait(): - # function under test - from ctapipe.utils.tools import enum_trait - - trait = enum_trait(ImageExtractor, default='NeighborPeakWindowSum') - assert isinstance(trait, traitlets.traitlets.CaselessStrEnum) - - -def test_enum_classes_with_traits(): - # function under test - from ctapipe.utils.tools import classes_with_traits - - list_of_classes = classes_with_traits(ImageExtractor) - assert list_of_classes # should not be empty - - -def test_has_traits(): - # function under test - from ctapipe.utils.tools import has_traits - - class WithoutTraits(HasTraits): - pass - - class WithATrait(HasTraits): - my_trait = Int() - - assert not has_traits(WithoutTraits) - assert has_traits(WithATrait) diff --git a/ctapipe/utils/tools.py b/ctapipe/utils/tools.py deleted file mode 100644 index 4ca4c7cba96..00000000000 --- a/ctapipe/utils/tools.py +++ /dev/null @@ -1,53 +0,0 @@ -'''some utils for Tool Developers -''' -from ctapipe.core import non_abstract_children -from traitlets import CaselessStrEnum - - -def enum_trait(base_class, default, help_str=None): - '''create a configurable CaselessStrEnum traitlet from baseclass - - the enumeration should contain all names of non_abstract_children() - of said baseclass and the default choice should be given by - `base_class._default` name. - - default must be specified and must be the name of one child-class - ''' - if help_str is None: - help_str = '{} to use.'.format(base_class.__name__) - - choices = [ - cls.__name__ - for cls in non_abstract_children(base_class) - ] - if default not in choices: - raise ValueError( - '{default} is not in choices: {choices}'.format( - default=default, - choices=choices, - ) - ) - - return CaselessStrEnum( - choices, - default, - allow_none=True, - help=help_str - ).tag(config=True) - - -def classes_with_traits(base_class): - all_classes = [base_class] + non_abstract_children(base_class) - return [cls for cls in all_classes if has_traits(cls)] - - -def has_traits(cls, ignore=('config', 'parent')): - '''True if cls has any traits apart from the usual ones - - all our components have at least 'config' and 'parent' as traitlets - this is inherited from `traitlets.config.Configurable` so we ignore them - here. - ''' - return bool( - set(cls.class_trait_names()) - set(ignore) - ) diff --git a/docs/examples/Tools.ipynb b/docs/examples/Tools.ipynb index 918498efcc1..db9fd6fd77d 100644 --- a/docs/examples/Tools.ipynb +++ b/docs/examples/Tools.ipynb @@ -42,8 +42,7 @@ "outputs": [], "source": [ "class MyComponent(Component):\n", - " description = \"Do some things\"\n", - "\n", + " \"\"\" A Component that does stuff \"\"\"\n", " value = Integer(default_value=-1, help=\"Value to use\").tag(config=True)\n", "\n", " def do_thing(self):\n", @@ -55,8 +54,7 @@ "\n", "\n", "class AdvancedComponent(Component):\n", - " name=\"AdvancedComponent\"\n", - " description = \"something more advanced\"\n", + " \"\"\" An advanced technique \"\"\"\n", "\n", " value1 = Integer(default_value=-1, help=\"Value to use\").tag(config=True)\n", " infile = Unicode(help=\"input file name\").tag(config=True)\n", @@ -67,6 +65,24 @@ " self.log.warning(\"Outfile was changed to '{}'\".format(change))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MyComponent()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "AdvancedComponent()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -93,11 +109,15 @@ " iterations = Integer(5,help=\"Number of times to run\",allow_none=False).tag(config=True)\n", "\n", " def setup_comp(self):\n", - " self.comp = MyComponent(parent=self)\n", - " self.comp2 = SecondaryMyComponent(parent=self)\n", + " # when constructing Components, you must add them to the \n", + " # list of registered instances using add_component. This allows\n", + " # the full configuration to be tracked\n", + " self.comp = self.add_component(MyComponent(parent=self))\n", + " self.comp2 = self.add_component(SecondaryMyComponent(parent=self))\n", + " \n", "\n", " def setup_advanced(self):\n", - " self.advanced = AdvancedComponent(parent=self)\n", + " self.advanced = self.add_component(AdvancedComponent(parent=self))\n", "\n", " def setup(self):\n", " self.setup_comp()\n", @@ -322,6 +342,13 @@ "print(tool2.config)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -366,6 +393,75 @@ "source": [ "tool3.is_setup" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tool3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tool.comp2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting the configuration of an instance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tool.get_current_config()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tool.iterations = 12\n", + "tool.get_current_config()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Writing a Sample Config File" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(tool.generate_config_file())" + ] } ], "metadata": { @@ -384,7 +480,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.7.3" } }, "nbformat": 4, diff --git a/examples/simple_event_writer.py b/examples/simple_event_writer.py index 826852d9867..3d0f939d7c1 100755 --- a/examples/simple_event_writer.py +++ b/examples/simple_event_writer.py @@ -38,18 +38,23 @@ class SimpleEventWriter(Tool): def setup(self): self.log.info('Configure EventSource...') - self.event_source = EventSource.from_config( - config=self.config, - parent=self + self.event_source = self.add_component( + EventSource.from_config( + config=self.config, + parent=self + ) ) - self.event_source.allowed_tels = self.config['Analysis']['allowed_tels'] - self.calibrator = CameraCalibrator( - parent=self + self.calibrator = self.add_component( + CameraCalibrator(parent=self) ) - self.writer = HDF5TableWriter( - filename=self.outfile, group_name='image_infos', overwrite=True + self.writer = self.add_component( + HDF5TableWriter( + filename=self.outfile, + group_name='image_infos', + overwrite=True + ) ) # Define Pre-selection for images diff --git a/examples/tool_example.py b/examples/tool_example.py new file mode 100644 index 00000000000..20572a72c1a --- /dev/null +++ b/examples/tool_example.py @@ -0,0 +1,106 @@ +"""A simple example of how to use traitlets.config.application.Application. +This should serve as a simple example that shows how the traitlets config +system works. The main classes are: +* traitlets.config.Configurable +* traitlets.config.SingletonConfigurable +* traitlets.config.Config +* traitlets.config.Application +To see the command line option help, run this program from the command line:: + $ python test_tool.py --help +To make one of your classes configurable (from the command line and config +files) inherit from Configurable and declare class attributes as traits (see +classes Foo and Bar below). To make the traits configurable, you will need +to set the following options: +* ``config``: set to ``True`` to make the attribute configurable. +* ``shortname``: by default, configurable attributes are set using the syntax + "Classname.attributename". At the command line, this is a bit verbose, so + we allow "shortnames" to be declared. Setting a shortname is optional, but + when you do this, you can set the option at the command line using the + syntax: "shortname=value". +* ``help``: set the help string to display a help message when the ``-h`` + option is given at the command line. The help string should be valid ReST. +When the config attribute of an Application is updated, it will fire all of +the trait's events for all of the config=True attributes. +""" + +from traitlets import Bool, Unicode, Int, List, Dict + +from ctapipe.core import Component, Tool + + +class AComponent(Component): + """ + A class that has configurable, typed attributes. + """ + + i = Int(0, help="The integer i.").tag(config=True) + j = Int(1, help="The integer j.").tag(config=True) + name = Unicode("Brian", help="First name.").tag(config=True) + + def __call__(self): + self.log.info("CALLED FOO") + + +class BComponent(Component): + """ Some Other Component """ + + enabled = Bool(True, help="Enable bar.").tag(config=True) + + +class MyTool(Tool): + """ My Tool """ + + name = Unicode("myapp") + running = Bool(False, help="Is the app running?").tag(config=True) + classes = List([BComponent, AComponent]) + config_file = Unicode("", help="Load this config file").tag(config=True) + + aliases = Dict( + dict( + i="Foo.i", + j="Foo.j", + name="Foo.name", + running="MyApp.running", + enabled="Bar.enabled", + log_level="MyApp.log_level", + ) + ) + + flags = Dict( + dict( + enable=({"Bar": {"enabled": True}}, "Enable Bar"), + disable=({"Bar": {"enabled": False}}, "Disable Bar"), + debug=({"MyApp": {"log_level": 10}}, "Set loglevel to DEBUG"), + ) + ) + + def init_a_component(self): + """ setup the Foo component""" + self.log.info("INIT FOO") + self.a_component = self.add_component(AComponent(parent=self)) + + def init_b_component(self): + """ setup the Bar component""" + self.log.info("INIT BAR") + self.b_component = self.add_component(BComponent(parent=self)) + + def setup(self): + """ Setup all components and the tool""" + self.init_a_component() + self.init_b_component() + + def start(self): + """ run the tool""" + self.log.info("app.config:") + self.log.info("THE CONFIGURATION: %s", self.get_current_config()) + self.a_component() + + +def main(): + """ run the app """ + tool = MyTool() + tool.run() + + +if __name__ == "__main__": + main()