Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specify an input glob pattern #16

Merged
merged 6 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Remember to replace `[version]` with [a deciles-charts version][4]:
generate_deciles_charts:
run: >
deciles-charts:[version]
--input_dir output
--input-files output/measure_*.csv
--output-dir output
needs: [generate_measures]
outputs:
Expand Down
35 changes: 19 additions & 16 deletions analysis/deciles_charts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import glob
import pathlib
import re

Expand All @@ -17,18 +18,12 @@ def _get_group_by(measure_table):
return list(measure_table.columns[:-4])


def get_measure_tables(path):
if not path.is_dir():
raise AttributeError()

for sub_path in path.iterdir():
if not sub_path.is_file():
continue

measure_fname_match = re.match(MEASURE_FNAME_REGEX, sub_path.name)
def get_measure_tables(input_files):
for input_file in input_files:
measure_fname_match = re.match(MEASURE_FNAME_REGEX, input_file.name)
if measure_fname_match is not None:
# The `date` column is assigned by the measures framework.
measure_table = pandas.read_csv(sub_path, parse_dates=["date"])
measure_table = pandas.read_csv(input_file, parse_dates=["date"])

# We can reconstruct the parameters passed to `Measure` without
# the study definition.
Expand All @@ -52,29 +47,37 @@ def write_deciles_chart(deciles_chart, path):
deciles_chart.savefig(path)


def get_path(*args):
    """Return an absolute path built from the given path segments.

    The segments in ``args`` are joined with ``pathlib.Path`` and the result
    is resolved, which makes it absolute (relative to the current working
    directory). The path does not have to exist.

    Args:
        *args: Path segments, as accepted by ``pathlib.Path``.

    Returns:
        pathlib.Path: The resolved path.
    """
    return pathlib.Path(*args).resolve()


def match_paths(pattern):
    """Expand a glob pattern into a sorted list of resolved paths.

    Args:
        pattern: A glob pattern, e.g. ``"output/measure_*.csv"``.

    Returns:
        list[pathlib.Path]: One resolved path per match, in sorted order.
        The list is empty when nothing matches the pattern.
    """
    # `glob.glob` returns matches in arbitrary order; sorting makes the order
    # in which input files are processed reproducible across filesystems.
    return sorted(get_path(match) for match in glob.glob(pattern))


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-dir",
"--input-files",
required=True,
type=pathlib.Path,
help="Path to the input directory",
type=match_paths,
help="Glob pattern for matching one or more input files",
)
parser.add_argument(
"--output-dir",
required=True,
type=pathlib.Path,
type=get_path,
help="Path to the output directory",
)
return parser.parse_args()


def main():
args = parse_args()
input_dir = args.input_dir
input_files = args.input_files
output_dir = args.output_dir

for measure_table in get_measure_tables(input_dir):
for measure_table in get_measure_tables(input_files):
measure_table = drop_zero_denominator_rows(measure_table)
chart = get_deciles_chart(measure_table)
id_ = measure_table.attrs["id"]
Expand Down
2 changes: 1 addition & 1 deletion project.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ actions:
generate_deciles_charts:
run: >
python:latest analysis/deciles_charts.py
--input-dir output
--input-files output/measure_*.csv
--output-dir output
needs: [generate_measures]
outputs:
Expand Down
111 changes: 63 additions & 48 deletions tests/test_deciles_charts.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,48 @@
import pandas
import pytest
from pandas import testing

from analysis import deciles_charts


class TestGetMeasureTables:
def test_path_is_not_dir(self, tmp_path):
tmp_file = tmp_path / "measure_sbp_by_practice.csv"
tmp_file.touch()
with pytest.raises(AttributeError):
next(deciles_charts.get_measure_tables(tmp_file))

def test_no_recurse(self, tmp_path):
tmp_sub_path = tmp_path / "measures"
tmp_sub_path.mkdir()
with pytest.raises(StopIteration):
next(deciles_charts.get_measure_tables(tmp_path))

def test_input_table(self, tmp_path):
tmp_file = tmp_path / "input_2019-01-01.csv"
tmp_file.touch()
with pytest.raises(StopIteration):
next(deciles_charts.get_measure_tables(tmp_path))

def test_measure_table(self, tmp_path):
# arrange
measure_table_in = pandas.DataFrame(
{
"practice": [1], # group_by
"has_sbp_event": [1], # numerator
"population": [1], # denominator
"value": [1], # assigned by the measures framework
"date": ["2021-01-01"], # assigned by the measures framework
}
)
measure_table_in["date"] = pandas.to_datetime(measure_table_in["date"])
measure_table_in.to_csv(tmp_path / "measure_sbp_by_practice.csv", index=False)

# act
measure_table_out = next(deciles_charts.get_measure_tables(tmp_path))

# assert
testing.assert_frame_equal(measure_table_out, measure_table_in)
assert measure_table_out.attrs["id"] == "sbp_by_practice"
assert measure_table_out.attrs["denominator"] == "population"
assert measure_table_out.attrs["group_by"] == ["practice"]
def test_get_measure_tables(tmp_path):
# For each measure, the measures framework writes a csv for each week/month, adding
# the date as a suffix to the file name; and a csv for all weeks/months, adding the
# date to a column in the file. We define a "measure table" as the latter, because
# it's easier to work with one file, than with many files/file names. However, it's
# hard to write a glob pattern that matches the latter but not the former, so
# `get_measure_tables` filters `input_files`.

# arrange
# this is a csv for a week/month
input_file_1 = tmp_path / "measure_sbp_by_practice_2021-01-01.csv"
input_file_1.touch()

# this is a csv for all weeks/months
measure_table_in = pandas.DataFrame(
{
"practice": [1], # group_by
"has_sbp_event": [1], # numerator
"population": [1], # denominator
"value": [1], # assigned by the measures framework
"date": ["2021-01-01"], # assigned by the measures framework
}
)
measure_table_in["date"] = pandas.to_datetime(measure_table_in["date"])
input_file_2 = tmp_path / "measure_sbp_by_practice.csv"
measure_table_in.to_csv(input_file_2, index=False)

# act
measure_tables_out = list(
deciles_charts.get_measure_tables([input_file_1, input_file_2])
)

# assert
assert len(measure_tables_out) == 1
measure_table_out = measure_tables_out[0]
testing.assert_frame_equal(measure_table_out, measure_table_in)
assert measure_table_out.attrs["id"] == "sbp_by_practice"
assert measure_table_out.attrs["denominator"] == "population"
assert measure_table_out.attrs["group_by"] == ["practice"]


def test_drop_zero_denominator_rows():
Expand Down Expand Up @@ -88,18 +85,36 @@ def test_parse_args(tmp_path, monkeypatch):
"sys.argv",
[
"deciles_charts.py",
"--input-dir",
"input",
"--input-files",
"input/measure_*.csv",
"--output-dir",
"output",
],
)
(tmp_path / "input").mkdir()
(tmp_path / "output").mkdir()

input_dir = tmp_path / "input"
input_dir.mkdir()

input_files = []
for input_file_name in [
"measure_has_sbp_event_by_stp_code.csv",
"measure_has_sbp_event_by_stp_code_2021-01-01.csv",
"measure_has_sbp_event_by_stp_code_2021-02-01.csv",
"measure_has_sbp_event_by_stp_code_2021-03-01.csv",
"measure_has_sbp_event_by_stp_code_2021-04-01.csv",
"measure_has_sbp_event_by_stp_code_2021-05-01.csv",
"measure_has_sbp_event_by_stp_code_2021-06-01.csv",
]:
input_file = input_dir / input_file_name
input_file.touch()
input_files.append(input_file)

output_dir = tmp_path / "output"
output_dir.mkdir()

# act
args = deciles_charts.parse_args()

# assert
args.input_dir == "input"
args.output_dir == "output"
assert sorted(args.input_files) == sorted(input_files)
assert args.output_dir == output_dir