diff --git a/README.md b/README.md index 2085de8..79e7236 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Remember to replace `[version]` with [a deciles-charts version][4]: generate_deciles_charts: run: > deciles-charts:[version] - --input_dir output + --input-files output/measure_*.csv --output_dir output needs: [generate_measures] outputs: diff --git a/analysis/deciles_charts.py b/analysis/deciles_charts.py index e6916bb..85bc09c 100644 --- a/analysis/deciles_charts.py +++ b/analysis/deciles_charts.py @@ -1,4 +1,5 @@ import argparse +import glob import pathlib import re @@ -17,18 +18,12 @@ def _get_group_by(measure_table): return list(measure_table.columns[:-4]) -def get_measure_tables(path): - if not path.is_dir(): - raise AttributeError() - - for sub_path in path.iterdir(): - if not sub_path.is_file(): - continue - - measure_fname_match = re.match(MEASURE_FNAME_REGEX, sub_path.name) +def get_measure_tables(input_files): + for input_file in input_files: + measure_fname_match = re.match(MEASURE_FNAME_REGEX, input_file.name) if measure_fname_match is not None: # The `date` column is assigned by the measures framework. - measure_table = pandas.read_csv(sub_path, parse_dates=["date"]) + measure_table = pandas.read_csv(input_file, parse_dates=["date"]) # We can reconstruct the parameters passed to `Measure` without # the study definition. @@ -52,13 +47,21 @@ def write_deciles_chart(deciles_chart, path): deciles_chart.savefig(path) +def get_path(*args): + return pathlib.Path(*args).resolve() + + +def match_paths(pattern): + return [get_path(x) for x in glob.glob(pattern)] + + def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--input-dir", + "--input-files", required=True, - type=pathlib.Path, - help="Path to the input directory", + type=match_paths, + help="Glob pattern for matching one or more input files", ) parser.add_argument( "--output-dir", @@ -71,10 +74,10 @@ def parse_args(): def main(): args = parse_args() - input_dir = args.input_dir + input_files = args.input_files output_dir = args.output_dir - for measure_table in get_measure_tables(input_dir): + for measure_table in get_measure_tables(input_files): measure_table = drop_zero_denominator_rows(measure_table) chart = get_deciles_chart(measure_table) id_ = measure_table.attrs["id"] diff --git a/project.yaml b/project.yaml index 3add6ba..2d24814 100644 --- a/project.yaml +++ b/project.yaml @@ -27,7 +27,7 @@ actions: generate_deciles_charts: run: > python:latest analysis/deciles_charts.py - --input-dir output + --input-files output/measure_*.csv --output-dir output needs: [generate_measures] outputs: diff --git a/tests/test_deciles_charts.py b/tests/test_deciles_charts.py index de9413b..a32af5f 100644 --- a/tests/test_deciles_charts.py +++ b/tests/test_deciles_charts.py @@ -8,23 +8,11 @@ class TestGetMeasureTables: - def test_path_is_not_dir(self, tmp_path): - tmp_file = tmp_path / "measure_sbp_by_practice.csv" - tmp_file.touch() - with pytest.raises(AttributeError): - next(deciles_charts.get_measure_tables(tmp_file)) - - def test_no_recurse(self, tmp_path): - tmp_sub_path = tmp_path / "measures" - tmp_sub_path.mkdir() - with pytest.raises(StopIteration): - next(deciles_charts.get_measure_tables(tmp_path)) - def test_input_table(self, tmp_path): tmp_file = tmp_path / "input_2019-01-01.csv" tmp_file.touch() with pytest.raises(StopIteration): - next(deciles_charts.get_measure_tables(tmp_path)) + next(deciles_charts.get_measure_tables([tmp_path])) def test_measure_table(self, tmp_path): # arrange @@ -38,10 +26,11 @@ def test_measure_table(self, tmp_path): } ) measure_table_in["date"] = pandas.to_datetime(measure_table_in["date"]) - measure_table_in.to_csv(tmp_path / "measure_sbp_by_practice.csv", index=False) + input_file = tmp_path / "measure_sbp_by_practice.csv" + measure_table_in.to_csv(input_file, index=False) # act - measure_table_out = next(deciles_charts.get_measure_tables(tmp_path)) + measure_table_out = next(deciles_charts.get_measure_tables([input_file])) # assert testing.assert_frame_equal(measure_table_out, measure_table_in) @@ -90,18 +79,35 @@ def test_parse_args(tmp_path, monkeypatch): "sys.argv", [ "deciles_charts.py", - "--input-dir", - "input", + "--input-files", + "input/measure_*.csv", "--output-dir", "output", ], ) - (tmp_path / "input").mkdir() + + input_dir = tmp_path / "input" + input_dir.mkdir() + + input_files = [] + for input_file_name in [ + "measure_has_sbp_event_by_stp_code.csv", + "measure_has_sbp_event_by_stp_code_2021-01-01.csv", + "measure_has_sbp_event_by_stp_code_2021-02-01.csv", + "measure_has_sbp_event_by_stp_code_2021-03-01.csv", + "measure_has_sbp_event_by_stp_code_2021-04-01.csv", + "measure_has_sbp_event_by_stp_code_2021-05-01.csv", + "measure_has_sbp_event_by_stp_code_2021-06-01.csv", + ]: + input_file = input_dir / input_file_name + input_file.touch() + input_files.append(input_file) + (tmp_path / "output").mkdir() # act args = deciles_charts.parse_args() # assert - assert args.input_dir == pathlib.Path("input") + assert sorted(args.input_files) == sorted(input_files) assert args.output_dir == pathlib.Path("output")