Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow POI-less models via Workspace.model #1636

Merged
merged 14 commits into from
Oct 15, 2021
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def setup(app):
'matplotlib': ('https://matplotlib.org/stable/', None),
'iminuit': ('https://iminuit.readthedocs.io/en/stable/', None),
'uproot': ('https://uproot.readthedocs.io/en/latest/', None),
'jsonpatch': ('https://python-json-patch.readthedocs.io/en/latest/', None),
}

# GitHub repo
Expand Down
59 changes: 33 additions & 26 deletions src/pyhf/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,24 +319,20 @@ def __repr__(self):
"""Representation of the Workspace."""
return object.__repr__(self)

def get_measurement(
self, poi_name=None, measurement_name=None, measurement_index=None
):
def get_measurement(self, measurement_name=None, measurement_index=None):
"""
Get (or create) a measurement object.
Get a measurement object.

The following logic is used:

1. if the poi name is given, create a measurement object for that poi
2. if the measurement name is given, find the measurement for the given name
3. if the measurement index is given, return the measurement at that index
4. if there are measurements but none of the above have been specified, return the 0th measurement
1. if the measurement name is given, find the measurement for the given name
2. if the measurement index is given, return the measurement at that index
3. if there are measurements but none of the above have been specified, return the 0th measurement

Raises:
~pyhf.exceptions.InvalidMeasurement: If the measurement was not found

Args:
poi_name (:obj:`str`): The name of the parameter of interest to create a new measurement from
measurement_name (:obj:`str`): The name of the measurement to use
measurement_index (:obj:`int`): The index of the measurement to use

Expand All @@ -345,12 +341,7 @@ def get_measurement(

"""
measurement = None
if poi_name is not None:
measurement = {
'name': 'NormalMeasurement',
'config': {'poi': poi_name, 'parameters': []},
}
elif self.measurement_names:
if self.measurement_names:
if measurement_name is not None:
if measurement_name not in self.measurement_names:
log.debug(f"measurements defined: {self.measurement_names}")
Expand Down Expand Up @@ -381,41 +372,57 @@ def get_measurement(
utils.validate(measurement, 'measurement.json', self.version)
return measurement

def model(self, **config_kwargs):
def model(
self,
measurement_name=None,
measurement_index=None,
patches=None,
**config_kwargs,
):
"""
Create a model object with/without patches applied.

See :func:`pyhf.workspace.Workspace.get_measurement` and :class:`pyhf.pdf.Model` for possible keyword arguments.

Args:
patches: A list of JSON patches to apply to the model specification
config_kwargs: Possible keyword arguments for the measurement and model configuration
measurement_name (:obj:`str`): The name of the measurement to use
in :func:`~pyhf.workspace.Workspace.get_measurement`.
measurement_index (:obj:`int`): The index of the measurement to use
in :func:`~pyhf.workspace.Workspace.get_measurement`.
patches (:obj:`list` of :class:`jsonpatch.JsonPatch` or :class:`pyhf.patchset.Patch`):
A list of patches to apply to the model specification.
config_kwargs: Possible keyword arguments for the model
configuration.
See :class:`~pyhf.pdf.Model` for more details.
poi_name (:obj:`str` or :obj:`None`): Specify this keyword argument
to override the default parameter of interest specified in the
measurement.
Set to :obj:`None` for a POI-less model.

Returns:
~pyhf.pdf.Model: A model object adhering to the schema model.json

"""

poi_name = config_kwargs.pop('poi_name', None)
measurement_name = config_kwargs.pop('measurement_name', None)
measurement_index = config_kwargs.pop('measurement_index', None)
measurement = self.get_measurement(
poi_name=poi_name,
measurement_name=measurement_name,
measurement_index=measurement_index,
)
log.debug(f"model being created for measurement {measurement['name']:s}")

patches = config_kwargs.pop('patches', [])
# set poi_name if the user does not provide it
config_kwargs.setdefault('poi_name', measurement['config']['poi'])

log.debug(f"model being created for measurement {measurement['name']:s}")

modelspec = {
'channels': self['channels'],
'parameters': measurement['config']['parameters'],
}

patches = patches or []
for patch in patches:
modelspec = jsonpatch.JsonPatch(patch).apply(modelspec)

return Model(modelspec, poi_name=measurement['config']['poi'], **config_kwargs)
return Model(modelspec, **config_kwargs)

def data(self, model, include_auxdata=True):
"""
Expand Down
33 changes: 21 additions & 12 deletions tests/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ def test_get_measurement(workspace_factory):
assert m['name'] == w.measurement_names[measurement_idx]


def test_get_measurement_fake(workspace_factory):
w = workspace_factory()
m = w.get_measurement(poi_name='fake_poi')
assert m


def test_get_measurement_nonexist(workspace_factory):
w = workspace_factory()
with pytest.raises(pyhf.exceptions.InvalidMeasurement) as excinfo:
Expand All @@ -98,12 +92,6 @@ def test_get_measurement_no_measurements_defined(workspace_factory):
def test_get_workspace_measurement_priority(workspace_factory):
w = workspace_factory()

# does poi_name override all others?
m = w.get_measurement(
poi_name='fake_poi', measurement_name='FakeMeasurement', measurement_index=999
)
assert m['config']['poi'] == 'fake_poi'

# does measurement_name override measurement_index?
m = w.get_measurement(
measurement_name=w.measurement_names[0], measurement_index=999
Expand Down Expand Up @@ -136,6 +124,27 @@ def test_get_workspace_model_default(workspace_factory):
assert m


def test_get_workspace_model_nopoi(workspace_factory):
w = workspace_factory()
m = w.model(poi_name=None)

assert m.config.poi_name is None
assert m.config.poi_index is None


def test_get_workspace_model_overridepoi(workspace_factory):
w = workspace_factory()
m = w.model(poi_name='lumi')

assert m.config.poi_name == 'lumi'


def test_get_workspace_model_fakepoi(workspace_factory):
w = workspace_factory()
with pytest.raises(pyhf.exceptions.InvalidModel):
w.model(poi_name='afakepoi')


def test_workspace_observations(workspace_factory):
w = workspace_factory()
assert w.observations
Expand Down