Skip to content

Commit

Permalink
Remove cohort extractor support in v4
Browse files Browse the repository at this point in the history
  • Loading branch information
alarthast committed Oct 4, 2024
1 parent 82e2182 commit 2c5cb81
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
31 changes: 21 additions & 10 deletions pipeline/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
validate_databuilder_outputs,
validate_glob_pattern,
validate_no_kwargs,
validate_not_cohort_extractor_action,
validate_type,
)

Expand Down Expand Up @@ -238,7 +239,7 @@ def is_database_action(self) -> bool:
class Pipeline:
version: float
actions: dict[str, Action]
expectations: Expectations
expectations: Expectations | None

@classmethod
def build(
Expand All @@ -262,13 +263,18 @@ def build(
raise ValidationError(
f"`version` must be a number between 1 and {LATEST_VERSION}"
)
feat = get_feature_flags_for_version(version)

validate_type(actions, dict, "Project `actions` section")
actions = {
action_id: Action.build(action_id, **action_config)
for action_id, action_config in actions.items()
}

if feat.REMOVE_SUPPORT_FOR_COHORT_EXTRACTOR:
for config in actions.values():
validate_not_cohort_extractor_action(config)

seen: dict[Command, list[str]] = defaultdict(list)
for name, config in actions.items():
run = config.run
Expand All @@ -278,7 +284,7 @@ def build(
)
seen[run].append(name)

if get_feature_flags_for_version(version).UNIQUE_OUTPUT_PATH:
if feat.UNIQUE_OUTPUT_PATH:
# find duplicate paths defined in the outputs section
seen_files = []
for config in actions.values():
Expand All @@ -298,19 +304,24 @@ def build(
f"Action `{a.action_id}` references an unknown action in its `needs` list: {n}"
)

feat = get_feature_flags_for_version(version)
if feat.EXPECTATIONS_POPULATION:
if feat.REMOVE_SUPPORT_FOR_COHORT_EXTRACTOR:
if expectations is not None:
raise ValidationError(
"Project includes `expectations` section, which is not supported in this version"
)
elif feat.EXPECTATIONS_POPULATION:
if expectations is None:
raise ValidationError("Project must include `expectations` section")
else:
expectations = {"population_size": 1000}

validate_type(expectations, dict, "Project `expectations` section")
if "population_size" not in expectations:
raise ValidationError(
"Project `expectations` section must include `population_size` section",
)
expectations = Expectations.build(**expectations)
if expectations is not None:
validate_type(expectations, dict, "Project `expectations` section")
if "population_size" not in expectations:
raise ValidationError(
"Project `expectations` section must include `population_size` section",
)
expectations = Expectations.build(**expectations)

return cls(version, actions, expectations)

Expand Down
7 changes: 7 additions & 0 deletions pipeline/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def validate_glob_pattern(pattern: str, privacy_level: str) -> None:
raise InvalidPatternError("is an absolute path")


def validate_not_cohort_extractor_action(action: Action) -> None:
    """
    Reject an action whose run command starts with cohortextractor.

    Looks at the first element of `action.run.parts`. If it begins with
    "cohortextractor", a ValidationError naming the action is raised;
    otherwise the function returns without doing anything.
    """
    first_part = action.run.parts[0]
    if not first_part.startswith("cohortextractor"):
        return

    raise ValidationError(
        f"Action {action.action_id} uses cohortextractor actions, which are not supported in this version."
    )


def validate_cohortextractor_outputs(action_id: str, action: Action) -> None:
"""
Check cohortextractor's output config is valid for this command
Expand Down
33 changes: 33 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,23 @@ def test_action_extraction_command_with_one_outputs():
assert len(outputs.values()) == 1


def test_cohortextractor_actions_not_used_after_v3():
    # building a version 4 project containing a cohortextractor action is
    # expected to fail with a ValidationError
    data = {
        "version": "4",
        "actions": {
            "generate_cohort": {
                "run": "cohortextractor:latest generate_cohort",
                "outputs": {
                    "highly_sensitive": {"cohort": "output/input.csv"},
                },
            },
        },
    }
    # `match` is applied as a regular expression, so the trailing full stop is
    # escaped to require a literal "." instead of matching any character
    msg = r"uses cohortextractor actions, which are not supported in this version\."
    with pytest.raises(ValidationError, match=msg):
        Pipeline.build(**data)


def test_command_properties():
data = {
"version": 1,
Expand Down Expand Up @@ -168,6 +185,22 @@ def test_expectations_before_v3_has_a_default_set():
assert config.expectations.population_size == 1000


def test_expectations_does_not_exist_after_v3():
    # a version 4 project that still supplies an `expectations` section is
    # expected to be rejected with a ValidationError
    dataset_action = {
        "run": "ehrql:v1 generate-dataset args --output output/dataset.csv.gz",
        "outputs": {"highly_sensitive": {"dataset": "output/dataset.csv.gz"}},
    }
    project = dict(
        version=4,
        expectations={},
        actions={"generate_dataset": dataset_action},
    )
    expected_message = "Project includes `expectations` section"
    with pytest.raises(ValidationError, match=expected_message):
        Pipeline.build(**project)


def test_expectations_exists_for_v3():
# our logic for this is custom so ensure it works as expected
data = {
Expand Down

0 comments on commit 2c5cb81

Please sign in to comment.