diff --git a/docs/conf.py b/docs/conf.py index 40e605b721..7bb6962aaa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,6 +77,7 @@ def setup(app): 'matplotlib': ('https://matplotlib.org/stable/', None), 'iminuit': ('https://iminuit.readthedocs.io/en/stable/', None), 'uproot': ('https://uproot.readthedocs.io/en/latest/', None), + 'jsonpatch': ('https://python-json-patch.readthedocs.io/en/latest/', None), } # GitHub repo diff --git a/src/pyhf/workspace.py b/src/pyhf/workspace.py index a9bfe18542..1abcf5113f 100644 --- a/src/pyhf/workspace.py +++ b/src/pyhf/workspace.py @@ -319,24 +319,20 @@ def __repr__(self): """Representation of the Workspace.""" return object.__repr__(self) - def get_measurement( - self, poi_name=None, measurement_name=None, measurement_index=None - ): + def get_measurement(self, measurement_name=None, measurement_index=None): """ - Get (or create) a measurement object. + Get a measurement object. The following logic is used: - 1. if the poi name is given, create a measurement object for that poi - 2. if the measurement name is given, find the measurement for the given name - 3. if the measurement index is given, return the measurement at that index - 4. if there are measurements but none of the above have been specified, return the 0th measurement + 1. if the measurement name is given, find the measurement for the given name + 2. if the measurement index is given, return the measurement at that index + 3. if there are measurements but none of the above have been specified, return the 0th measurement Raises: ~pyhf.exceptions.InvalidMeasurement: If the measurement was not found Args: - poi_name (:obj:`str`): The name of the parameter of interest to create a new measurement from measurement_name (:obj:`str`): The name of the measurement to use measurement_index (:obj:`int`): The index of the measurement to use @@ -345,12 +341,7 @@ def get_measurement( """ measurement = None - if poi_name is not None: - measurement = { - 'name': 'NormalMeasurement', - 'config': {'poi': poi_name, 'parameters': []}, - } - elif self.measurement_names: + if self.measurement_names: if measurement_name is not None: if measurement_name not in self.measurement_names: log.debug(f"measurements defined: {self.measurement_names}") @@ -381,41 +372,57 @@ def get_measurement( utils.validate(measurement, 'measurement.json', self.version) return measurement - def model(self, **config_kwargs): + def model( + self, + measurement_name=None, + measurement_index=None, + patches=None, + **config_kwargs, + ): """ Create a model object with/without patches applied. See :func:`pyhf.workspace.Workspace.get_measurement` and :class:`pyhf.pdf.Model` for possible keyword arguments. Args: - patches: A list of JSON patches to apply to the model specification - config_kwargs: Possible keyword arguments for the measurement and model configuration + measurement_name (:obj:`str`): The name of the measurement to use + in :func:`~pyhf.workspace.Workspace.get_measurement`. + measurement_index (:obj:`int`): The index of the measurement to use + in :func:`~pyhf.workspace.Workspace.get_measurement`. + patches (:obj:`list` of :class:`jsonpatch.JsonPatch` or :class:`pyhf.patchset.Patch`): + A list of patches to apply to the model specification. + config_kwargs: Possible keyword arguments for the model + configuration. + See :class:`~pyhf.pdf.Model` for more details. + poi_name (:obj:`str` or :obj:`None`): Specify this keyword argument + to override the default parameter of interest specified in the + measurement. + Set to :obj:`None` for a POI-less model. Returns: ~pyhf.pdf.Model: A model object adhering to the schema model.json """ - - poi_name = config_kwargs.pop('poi_name', None) - measurement_name = config_kwargs.pop('measurement_name', None) - measurement_index = config_kwargs.pop('measurement_index', None) measurement = self.get_measurement( - poi_name=poi_name, measurement_name=measurement_name, measurement_index=measurement_index, ) - log.debug(f"model being created for measurement {measurement['name']:s}") - patches = config_kwargs.pop('patches', []) + # set poi_name if the user does not provide it + config_kwargs.setdefault('poi_name', measurement['config']['poi']) + + log.debug(f"model being created for measurement {measurement['name']:s}") modelspec = { 'channels': self['channels'], 'parameters': measurement['config']['parameters'], } + + patches = patches or [] for patch in patches: modelspec = jsonpatch.JsonPatch(patch).apply(modelspec) - return Model(modelspec, poi_name=measurement['config']['poi'], **config_kwargs) + return Model(modelspec, **config_kwargs) def data(self, model, include_auxdata=True): """ diff --git a/tests/test_workspace.py b/tests/test_workspace.py index 0724968885..6c1a289ff0 100644 --- a/tests/test_workspace.py +++ b/tests/test_workspace.py @@ -67,12 +67,6 @@ def test_get_measurement(workspace_factory): assert m['name'] == w.measurement_names[measurement_idx] -def test_get_measurement_fake(workspace_factory): - w = workspace_factory() - m = w.get_measurement(poi_name='fake_poi') - assert m - - def test_get_measurement_nonexist(workspace_factory): w = workspace_factory() with pytest.raises(pyhf.exceptions.InvalidMeasurement) as excinfo: @@ -98,12 +92,6 @@ def test_get_measurement_no_measurements_defined(workspace_factory): def test_get_workspace_measurement_priority(workspace_factory): w = workspace_factory() - # does poi_name override all others? - m = w.get_measurement( - poi_name='fake_poi', measurement_name='FakeMeasurement', measurement_index=999 - ) - assert m['config']['poi'] == 'fake_poi' - # does measurement_name override measurement_index? m = w.get_measurement( measurement_name=w.measurement_names[0], measurement_index=999 @@ -136,6 +124,27 @@ def test_get_workspace_model_default(workspace_factory): assert m +def test_get_workspace_model_nopoi(workspace_factory): + w = workspace_factory() + m = w.model(poi_name=None) + + assert m.config.poi_name is None + assert m.config.poi_index is None + + +def test_get_workspace_model_overridepoi(workspace_factory): + w = workspace_factory() + m = w.model(poi_name='lumi') + + assert m.config.poi_name == 'lumi' + + +def test_get_workspace_model_fakepoi(workspace_factory): + w = workspace_factory() + with pytest.raises(pyhf.exceptions.InvalidModel): + w.model(poi_name='afakepoi') + + def test_workspace_observations(workspace_factory): w = workspace_factory() assert w.observations