feat: Handle absolute paths in XML config files (xml2json / readxml) (#1909)

* Add support for XML configurations that use absolute paths by pruning
the paths out and remapping them on the fly in readxml (see the usage
sketch after this list).
* Add a -v/--mount command line option for xml2json, with behavior
similar to the Docker volume-mount flag, to support this feature.
* Add tests and test XML files with absolute paths.
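
A minimal usage sketch of the new mounts keyword argument through the Python API (the host and XML paths here are hypothetical):

    from pathlib import Path

    from pyhf import readxml

    # Map where the ROOT files actually live on this machine (first entry)
    # to the absolute path recorded in the XML configuration (second entry),
    # in the spirit of a Docker volume mount.
    mounts = [(Path("/home/user/histograms"), Path("/abs/path/in/xml"))]
    spec = readxml.parse("config/example.xml", Path.cwd(), mounts=mounts)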
kratsg authored Aug 11, 2022
1 parent 1ed7003 commit 97664bc
Showing 9 changed files with 339 additions and 23 deletions.
14 changes: 13 additions & 1 deletion src/pyhf/cli/rootio.py
@@ -6,6 +6,7 @@
import os
from pathlib import Path
import jsonpatch
from pyhf.utils import VolumeMountPath

log = logging.getLogger(__name__)

@@ -23,14 +24,24 @@ def cli():
type=click.Path(exists=True),
default=Path.cwd(),
)
@click.option(
'-v',
'--mount',
help='Consists of two fields separated by a colon character ( : ): the first field is the local path where the files are located, and the second field is the path where the file or directory is recorded in the XML configuration. This is similar in spirit to Docker volume mounts.',
type=VolumeMountPath(exists=True, resolve_path=True, path_type=Path),
default=None,
multiple=True,
)
@click.option(
'--output-file',
help='The location of the output json file. If not specified, prints to screen.',
default=None,
)
@click.option('--track-progress/--hide-progress', default=True)
@click.option('--validation-as-error/--validation-as-warning', default=True)
def xml2json(entrypoint_xml, basedir, output_file, track_progress, validation_as_error):
def xml2json(
entrypoint_xml, basedir, mount, output_file, track_progress, validation_as_error
):
"""Entrypoint XML: The top-level XML file for the PDF definition."""
try:
import uproot
@@ -47,6 +58,7 @@ def xml2json(entrypoint_xml, basedir, output_file, track_progress, validation_as
spec = readxml.parse(
entrypoint_xml,
basedir,
mounts=mount,
track_progress=track_progress,
validation_as_error=validation_as_error,
)
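The equivalent command line invocation with the same hypothetical paths; -v may be repeated to supply several mounts, as the tests below exercise:

    pyhf xml2json -v /home/user/histograms:/abs/path/in/xml config/example.xml
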
87 changes: 67 additions & 20 deletions src/pyhf/readxml.py
@@ -1,18 +1,28 @@
from pyhf import schema
from pyhf import compat
from pyhf import exceptions
from __future__ import annotations

import logging

from pathlib import Path
import os
from typing import TYPE_CHECKING, Callable, Iterable, Tuple, Union, IO
import xml.etree.ElementTree as ET
from pathlib import Path

import numpy as np
import tqdm
import uproot

from pyhf import compat
from pyhf import exceptions
from pyhf import schema

log = logging.getLogger(__name__)

if TYPE_CHECKING:
PathOrStr = Union[str, os.PathLike[str]]
else:
PathOrStr = Union[str, "os.PathLike[str]"]

__FILECACHE__ = {}
MountPathType = Iterable[Tuple[Path, Path]]

__all__ = [
"clear_filecache",
@@ -31,6 +41,20 @@ def __dir__():
return __all__


def resolver_factory(rootdir: Path, mounts: MountPathType) -> Callable[[str], Path]:
def resolver(filename: str) -> Path:
path = Path(filename)
for host_path, mount_path in mounts:
# NB: path.parents doesn't include the path itself, which might be
# a directory as well, so check that edge case
if mount_path == path or mount_path in path.parents:
path = host_path.joinpath(path.relative_to(mount_path))
break
return rootdir.joinpath(path)

return resolver
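
# Behavior sketch (hypothetical paths, mirroring test_import_resolver below):
# with rootdir = Path("/work") and mounts = [(Path("/host/data"), Path("/xml/data"))]
#   resolver("hists/file.root")     -> Path("/work/hists/file.root")
#   resolver("/xml/data/file.root") -> Path("/host/data/file.root")
#   resolver("/other/file.root")    -> Path("/other/file.root")
# An unmapped absolute path passes through unchanged, because joinpath
# discards rootdir when given an absolute path.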


def extract_error(hist):
"""
Determine the bin uncertainties for a histogram.
@@ -50,14 +74,14 @@ def extract_error(hist):
return np.sqrt(variance).tolist()


def import_root_histogram(rootdir, filename, path, name, filecache=None):
def import_root_histogram(resolver, filename, path, name, filecache=None):
global __FILECACHE__
filecache = filecache or __FILECACHE__

# strip leading slashes as uproot doesn't use "/" for top-level
path = path or ''
path = path.strip('/')
fullpath = str(Path(rootdir).joinpath(filename))
fullpath = str(resolver(filename))
if fullpath not in filecache:
f = uproot.open(fullpath)
keys = set(f.keys(cycle=False))
@@ -79,15 +103,15 @@ def import_root_histogram(rootdir, filename, path, name, filecache=None):


def process_sample(
sample, rootdir, inputfile, histopath, channel_name, track_progress=False
sample, resolver, inputfile, histopath, channel_name, track_progress=False
):
if 'InputFile' in sample.attrib:
inputfile = sample.attrib.get('InputFile')
if 'HistoPath' in sample.attrib:
histopath = sample.attrib.get('HistoPath')
histoname = sample.attrib['HistoName']

data, err = import_root_histogram(rootdir, inputfile, histopath, histoname)
data, err = import_root_histogram(resolver, inputfile, histopath, histoname)

parameter_configs = []
modifiers = []
@@ -131,13 +155,13 @@ def process_sample(
parameter_configs.append(parameter_config)
elif modtag.tag == 'HistoSys':
lo, _ = import_root_histogram(
rootdir,
resolver,
modtag.attrib.get('HistoFileLow', inputfile),
modtag.attrib.get('HistoPathLow', ''),
modtag.attrib['HistoNameLow'],
)
hi, _ = import_root_histogram(
rootdir,
resolver,
modtag.attrib.get('HistoFileHigh', inputfile),
modtag.attrib.get('HistoPathHigh', ''),
modtag.attrib['HistoNameHigh'],
@@ -154,7 +178,7 @@ def process_sample(
staterr = err
else:
extstat, _ = import_root_histogram(
rootdir,
resolver,
modtag.attrib.get('HistoFile', inputfile),
modtag.attrib.get('HistoPath', ''),
modtag.attrib['HistoName'],
@@ -177,7 +201,7 @@ def process_sample(
modtag.attrib['Name'],
)
shapesys_data, _ = import_root_histogram(
rootdir,
resolver,
modtag.attrib.get('InputFile', inputfile),
modtag.attrib.get('HistoPath', ''),
modtag.attrib['HistoName'],
@@ -205,18 +229,18 @@ def process_sample(
}


def process_data(sample, rootdir, inputfile, histopath):
def process_data(sample, resolver, inputfile, histopath):
if 'InputFile' in sample.attrib:
inputfile = sample.attrib.get('InputFile')
if 'HistoPath' in sample.attrib:
histopath = sample.attrib.get('HistoPath')
histoname = sample.attrib['HistoName']

data, _ = import_root_histogram(rootdir, inputfile, histopath, histoname)
data, _ = import_root_histogram(resolver, inputfile, histopath, histoname)
return data


def process_channel(channelxml, rootdir, track_progress=False):
def process_channel(channelxml, resolver, track_progress=False):
channel = channelxml.getroot()

inputfile = channel.attrib.get('InputFile')
@@ -230,7 +254,7 @@ def process_channel(channelxml, rootdir, track_progress=False):

data = channel.findall('Data')
if data:
parsed_data = process_data(data[0], rootdir, inputfile, histopath)
parsed_data = process_data(data[0], resolver, inputfile, histopath)
else:
raise RuntimeError(f"Channel {channel_name} is missing data. See issue #1911.")

@@ -239,7 +263,7 @@ def process_channel(channelxml, rootdir, track_progress=False):
for sample in samples:
samples.set_description(f" - sample {sample.attrib.get('Name')}")
result = process_sample(
sample, rootdir, inputfile, histopath, channel_name, track_progress
sample, resolver, inputfile, histopath, channel_name, track_progress
)
channel_parameter_configs.extend(result.pop('parameter_configs'))
results.append(result)
@@ -343,20 +367,43 @@ def dedupe_parameters(parameters):
return list({v['name']: v for v in parameters}.values())


def parse(configfile, rootdir, track_progress=False, validation_as_error=True):
def parse(
configfile: PathOrStr | IO[bytes] | IO[str],
rootdir: PathOrStr,
mounts: MountPathType | None = None,
track_progress: bool = False,
validation_as_error: bool = True,
):
"""
Parse the ``configfile`` with respect to the ``rootdir``.
Args:
configfile (:class:`pathlib.Path` or :obj:`str` or file object): The top-level XML config file to parse.
rootdir (:class:`pathlib.Path` or :obj:`str`): The path to the working directory for interpreting relative paths in the configuration.
mounts (:obj:`None` or :obj:`list` of 2-:obj:`tuple` of :class:`pathlib.Path` objects): The first field is the local path where the files are located, and the second field is the path where the file or directory is recorded in the XML configuration. This is similar in spirit to Docker volume mounts. Default is ``None``.
track_progress (:obj:`bool`): Show the progress bar. Default is to hide the progress bar.
validation_as_error (:obj:`bool`): Throw an exception (``True``) or print a warning (``False``) if the resulting HistFactory JSON does not adhere to the schema. Default is to throw an exception.
Returns:
spec (:obj:`jsonable`): The newly built HistFactory JSON specification
"""
mounts = mounts or []
toplvl = ET.parse(configfile)
inputs = tqdm.tqdm(
[x.text for x in toplvl.findall('Input')],
unit='channel',
disable=not (track_progress),
)

# create a resolver for finding files
resolver = resolver_factory(Path(rootdir), mounts)

channels = {}
parameter_configs = []
for inp in inputs:
inputs.set_description(f'Processing {inp}')
channel, data, samples, channel_parameter_configs = process_channel(
ET.parse(Path(rootdir).joinpath(inp)), rootdir, track_progress
ET.parse(resolver(inp)), resolver, track_progress
)
channels[channel] = {'data': data, 'samples': samples}
parameter_configs.extend(channel_parameter_configs)
19 changes: 19 additions & 0 deletions src/pyhf/utils.py
@@ -2,6 +2,7 @@
import yaml
import click
import hashlib
from gettext import gettext

import sys

@@ -41,6 +42,24 @@ def convert(self, value, param, ctx):
self.fail(f'{value:s} is not a valid equal-delimited string', param, ctx)


class VolumeMountPath(click.Path):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = f'{self.name}:{gettext("path")}'

def convert(self, value, param, ctx):
try:
path_host, path_mount = value.split(':')
except ValueError:
# too many values to unpack / not enough values to unpack
self.fail(f"{value!r} is not a valid colon-separated option", param, ctx)

return (
super().convert(path_host, param, ctx),
self.coerce_path_result(path_mount),
)
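
# For example (hypothetical paths): "-v /host/data:/xml/data" with
# path_type=Path converts to (Path("/host/data"), Path("/xml/data")), i.e.
# the (host_path, mount_path) pairs consumed by readxml.resolver_factory.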


def digest(obj, algorithm='sha256'):
"""
Get the digest for the provided object. Note: object must be JSON-serializable.
30 changes: 28 additions & 2 deletions tests/test_import.py
@@ -111,7 +111,10 @@ def test_process_normfactor_configs():

def test_import_histogram():
data, uncert = pyhf.readxml.import_root_histogram(
"validation/xmlimport_input/data", "example.root", "", "data"
lambda x: Path("validation/xmlimport_input/data").joinpath(x),
"example.root",
"",
"data",
)
assert data == [122.0, 112.0]
assert uncert == [11.045360565185547, 10.58300495147705]
@@ -120,7 +123,10 @@ def test_import_histogram_KeyError():
def test_import_histogram_KeyError():
with pytest.raises(KeyError):
pyhf.readxml.import_root_histogram(
"validation/xmlimport_input/data", "example.root", "", "invalid_key"
lambda x: Path("validation/xmlimport_input/data").joinpath(x),
"example.root",
"",
"invalid_key",
)


@@ -498,3 +504,23 @@ def test_import_missingPOI(mocker, datadir):
assert 'Measurement GaussExample is missing POI specification' in str(
excinfo.value
)


def test_import_resolver(mocker):
rootdir = Path('/current/working/dir')
mounts = [(Path('/this/path/changed'), Path('/my/abs/path'))]
resolver = pyhf.readxml.resolver_factory(rootdir, mounts)

assert resolver('relative/path') == Path('/current/working/dir/relative/path')
assert resolver('relative/path/') == Path('/current/working/dir/relative/path')
assert resolver('relative/path/to/file.txt') == Path(
'/current/working/dir/relative/path/to/file.txt'
)
assert resolver('/absolute/path') == Path('/absolute/path')
assert resolver('/absolute/path/') == Path('/absolute/path')
assert resolver('/absolute/path/to/file.txt') == Path('/absolute/path/to/file.txt')
assert resolver('/my/abs/path') == Path('/this/path/changed')
assert resolver('/my/abs/path/') == Path('/this/path/changed')
assert resolver('/my/abs/path/to/file.txt') == Path(
'/this/path/changed/to/file.txt'
)
28 changes: 28 additions & 0 deletions tests/test_scripts.py
@@ -163,6 +163,34 @@ def test_import_prepHistFactory_and_cls(tmpdir, script_runner):
assert 'CLs_exp' in d


def test_import_usingMounts(datadir, tmpdir, script_runner):
data = datadir.joinpath("xmlimport_absolutePaths")

temp = tmpdir.join("parsed_output.json")
command = f'pyhf xml2json --hide-progress -v {data}:/absolute/path/to -v {data}:/another/absolute/path/to --output-file {temp.strpath:s} {data.joinpath("config/example.xml")}'

ret = script_runner.run(*shlex.split(command))
assert ret.success
assert ret.stdout == ''
assert ret.stderr == ''

parsed_xml = json.loads(temp.read())
spec = {'channels': parsed_xml['channels']}
pyhf.schema.validate(spec, 'model.json')


def test_import_usingMounts_badDelimitedPaths(datadir, tmpdir, script_runner):
data = datadir.joinpath("xmlimport_absolutePaths")

temp = tmpdir.join("parsed_output.json")
command = f'pyhf xml2json --hide-progress -v {data}::/absolute/path/to -v {data}/another/absolute/path/to --output-file {temp.strpath:s} {data.joinpath("config/example.xml")}'

ret = script_runner.run(*shlex.split(command))
assert not ret.success
assert ret.stdout == ''
assert 'is not a valid colon-separated option' in ret.stderr


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "pytorch", "jax"])
def test_fit_backend_option(tmpdir, script_runner, backend):
temp = tmpdir.join("parsed_output.json")