From b4bf07dc64f5784ba3af528a5a600989f971a714 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 01/20] Remove types They'll get in the way and we'll reinstate them later --- pipeline/models.py | 75 ++++++++++++++++------------------------------ pipeline/types.py | 33 -------------------- 2 files changed, 25 insertions(+), 83 deletions(-) delete mode 100644 pipeline/types.py diff --git a/pipeline/models.py b/pipeline/models.py index a89c595..458cb1c 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -2,14 +2,13 @@ import re import shlex from collections import defaultdict -from typing import Any, Dict, Iterable, List, Optional, Set, TypedDict +from typing import Any, Dict, List, Optional from pydantic import BaseModel, root_validator, validator from .constants import RUN_ALL_COMMAND from .exceptions import InvalidPatternError from .features import LATEST_VERSION, get_feature_flags_for_version -from .types import RawOutputs, RawPipeline from .validation import ( assert_valid_glob_pattern, validate_cohortextractor_outputs, @@ -29,7 +28,7 @@ } -def is_database_action(args: List[str]) -> bool: +def is_database_action(args): """ By default actions do not have database access, but certain trusted actions require it """ @@ -54,7 +53,7 @@ class Expectations(BaseModel): population_size: int @validator("population_size", pre=True) - def validate_population_size(cls, population_size: str) -> int: + def validate_population_size(cls, population_size): try: return int(population_size) except (TypeError, ValueError): @@ -68,11 +67,11 @@ class Outputs(BaseModel): moderately_sensitive: Optional[Dict[str, str]] minimally_sensitive: Optional[Dict[str, str]] - def __len__(self) -> int: + def __len__(self): return len(self.dict(exclude_unset=True)) @root_validator() - def at_least_one_output(cls, outputs: Dict[str, str]) -> Dict[str, str]: + def at_least_one_output(cls, outputs): if not any(outputs.values()): raise ValueError( f"must specify at least one output of: {', '.join(outputs)}" @@ -81,7 +80,7 @@ def at_least_one_output(cls, outputs: Dict[str, str]) -> Dict[str, str]: return outputs @root_validator(pre=True) - def validate_output_filenames_are_valid(cls, outputs: RawOutputs) -> RawOutputs: + def validate_output_filenames_are_valid(cls, outputs): # we use pre=True here so that we only get the outputs specified in the # input data. With Optional[…] wrapped fields pydantic will set None # for us and that just makes the logic a little fiddler with no @@ -105,20 +104,20 @@ class Config: frozen = True @property - def args(self) -> str: + def args(self): return " ".join(self.parts[1:]) @property - def name(self) -> str: + def name(self): # parts[0] with version split off return self.parts[0].split(":")[0] @property - def parts(self) -> List[str]: + def parts(self): return shlex.split(self.raw) @property - def version(self) -> str: + def version(self): # parts[0] with name split off return self.parts[0].split(":")[1] @@ -131,7 +130,7 @@ class Action(BaseModel): dummy_data_file: Optional[pathlib.Path] @validator("run", pre=True) - def parse_run_string(cls, run: str) -> Command: + def parse_run_string(cls, run): parts = shlex.split(run) name, _, version = parts[0].partition(":") @@ -143,33 +142,17 @@ def parse_run_string(cls, run: str) -> Command: return Command(raw=run) @property - def is_database_action(self) -> bool: + def is_database_action(self): return is_database_action(self.run.parts) -class PartiallyValidatedPipeline(TypedDict): - """ - A custom type to type-check the values in "post" root validators - - A root_validator with pre=False (or no kwargs) runs after the values have - been ingested already, and the `values` arg is a dictionary of model types. - - Note: This is defined here so we don't have to deal with forward reference - types. - """ - - version: float - expectations: Expectations - actions: Dict[str, Action] - - class Pipeline(BaseModel): version: float expectations: Expectations actions: Dict[str, Action] @property - def all_actions(self) -> List[str]: + def all_actions(self): """ Get all actions for this Pipeline instance @@ -180,9 +163,7 @@ def all_actions(self) -> List[str]: return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] @root_validator() - def validate_actions( - cls, values: PartiallyValidatedPipeline - ) -> PartiallyValidatedPipeline: + def validate_actions(cls, values): # TODO: move to Action when we move name onto it validators = { cohortextractor_pat: validate_cohortextractor_outputs, @@ -196,7 +177,7 @@ def validate_actions( return values @root_validator(pre=True) - def validate_expectations_per_version(cls, values: RawPipeline) -> RawPipeline: + def validate_expectations_per_version(cls, values): """Ensure the expectations key exists for version 3 onwards""" try: version = float(values["version"]) @@ -224,9 +205,7 @@ def validate_expectations_per_version(cls, values: RawPipeline) -> RawPipeline: return values @root_validator() - def validate_outputs_per_version( - cls, values: PartiallyValidatedPipeline - ) -> PartiallyValidatedPipeline: + def validate_outputs_per_version(cls, values): """ Ensure outputs are unique for version 2 onwards @@ -259,7 +238,7 @@ def validate_outputs_per_version( return values @root_validator(pre=True) - def validate_actions_run(cls, values: RawPipeline) -> RawPipeline: + def validate_actions_run(cls, values): # TODO: move to Action when we move name onto it for action_id, config in values.get("actions", {}).items(): if config["run"] == "": @@ -271,8 +250,8 @@ def validate_actions_run(cls, values: RawPipeline) -> RawPipeline: return values @validator("actions") - def validate_unique_commands(cls, actions: Dict[str, Action]) -> Dict[str, Action]: - seen: Dict[Command, List[str]] = defaultdict(list) + def validate_unique_commands(cls, actions): + seen = defaultdict(list) for name, config in actions.items(): run = config.run if run in seen: @@ -284,9 +263,7 @@ def validate_unique_commands(cls, actions: Dict[str, Action]) -> Dict[str, Actio return actions @validator("actions") - def validate_needs_are_comma_delimited( - cls, actions: Dict[str, Action] - ) -> Dict[str, Action]: + def validate_needs_are_comma_delimited(cls, actions): space_delimited = {} for name, action in actions.items(): # find needs definitions with spaces in them @@ -297,9 +274,7 @@ def validate_needs_are_comma_delimited( if not space_delimited: return actions - def iter_incorrect_needs( - space_delimited: Dict[str, List[str]] - ) -> Iterable[str]: + def iter_incorrect_needs(space_delimited): for name, needs in space_delimited.items(): yield f"Action: {name}" for need in needs: @@ -313,7 +288,7 @@ def iter_incorrect_needs( raise ValueError("\n".join(msg)) @validator("actions") - def validate_needs_exist(cls, actions: Dict[str, Action]) -> Dict[str, Action]: + def validate_needs_exist(cls, actions): missing = {} for name, action in actions.items(): unknown_needs = set(action.needs) - set(actions) @@ -323,7 +298,7 @@ def validate_needs_exist(cls, actions: Dict[str, Action]) -> Dict[str, Action]: if not missing: return actions - def iter_missing_needs(missing: Dict[str, Set[str]]) -> Iterable[str]: + def iter_missing_needs(missing): for name, needs in missing.items(): yield f"Action: {name}" for need in needs: @@ -336,7 +311,7 @@ def iter_missing_needs(missing: Dict[str, Set[str]]) -> Iterable[str]: raise ValueError("\n".join(msg)) @root_validator(pre=True) - def validate_version_exists(cls, values: RawPipeline) -> RawPipeline: + def validate_version_exists(cls, values): """ Ensure the version key exists. @@ -354,7 +329,7 @@ def validate_version_exists(cls, values: RawPipeline) -> RawPipeline: ) @validator("version", pre=True) - def validate_version_value(cls, value: str) -> float: + def validate_version_value(cls, value): try: return float(value) except (TypeError, ValueError): diff --git a/pipeline/types.py b/pipeline/types.py deleted file mode 100644 index aa359e5..0000000 --- a/pipeline/types.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Custom types to type check the raw input data - -When loading data from YAML we get dictionaries which pydantic attempts to -validate. Some of our validation is done via custom methods using the raw -dictionary data. -""" - -from __future__ import annotations - -import pathlib -from typing import Any, Dict, TypedDict - - -RawOutputs = Dict[str, Dict[str, str]] - - -class RawAction(TypedDict): - config: dict[Any, Any] | None - run: str - needs: list[str] | None - outputs: RawOutputs - dummy_data_file: pathlib.Path | None - - -class RawExpectations(TypedDict): - population_size: str | int | None - - -class RawPipeline(TypedDict): - version: str | float | int - expectations: RawExpectations - actions: dict[str, RawAction] From d79bd23f7531c619bacbaa7f20296fbc20de8d3b Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 02/20] Remove pydantic --- pipeline/exceptions.py | 4 + pipeline/legacy.py | 2 +- pipeline/main.py | 6 +- pipeline/models.py | 240 +++++++++++++++++++---------------------- pipeline/outputs.py | 2 +- pipeline/validation.py | 10 +- pyproject.toml | 4 - requirements.prod.txt | 4 - tests/test_main.py | 4 +- tests/test_models.py | 12 +-- 10 files changed, 132 insertions(+), 156 deletions(-) diff --git a/pipeline/exceptions.py b/pipeline/exceptions.py index 5befe71..313aace 100644 --- a/pipeline/exceptions.py +++ b/pipeline/exceptions.py @@ -8,3 +8,7 @@ class InvalidPatternError(ProjectValidationError): class YAMLError(Exception): pass + + +class ValidationError(Exception): + pass diff --git a/pipeline/legacy.py b/pipeline/legacy.py index 2ea83cf..7902164 100644 --- a/pipeline/legacy.py +++ b/pipeline/legacy.py @@ -7,6 +7,6 @@ def get_all_output_patterns_from_project_file(project_file: str) -> list[str]: config = load_pipeline(project_file) all_patterns = set() for action in config.actions.values(): - for patterns in action.outputs.dict(exclude_unset=True).values(): + for patterns in action.outputs.dict().values(): all_patterns.update(patterns.values()) return list(all_patterns) diff --git a/pipeline/main.py b/pipeline/main.py index 824a78d..e261be6 100644 --- a/pipeline/main.py +++ b/pipeline/main.py @@ -2,9 +2,7 @@ from pathlib import Path -import pydantic - -from .exceptions import ProjectValidationError, YAMLError +from .exceptions import ProjectValidationError, ValidationError, YAMLError from .loading import parse_yaml_file from .models import Pipeline @@ -28,7 +26,7 @@ def load_pipeline(pipeline_config: str | Path, filename: str | None = None) -> P # validate try: return Pipeline(**parsed_data) - except pydantic.ValidationError as exc: + except ValidationError as exc: raise ProjectValidationError( f"Invalid project: {filename or ''}\n{exc}" ) from exc diff --git a/pipeline/models.py b/pipeline/models.py index 458cb1c..5effb51 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -1,13 +1,9 @@ -import pathlib import re import shlex from collections import defaultdict -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, root_validator, validator from .constants import RUN_ALL_COMMAND -from .exceptions import InvalidPatternError +from .exceptions import InvalidPatternError, ValidationError from .features import LATEST_VERSION, get_feature_flags_for_version from .validation import ( assert_valid_glob_pattern, @@ -49,59 +45,71 @@ def is_database_action(args): return args[1] in db_commands -class Expectations(BaseModel): - population_size: int - - @validator("population_size", pre=True) - def validate_population_size(cls, population_size): +class Expectations: + def __init__(self, population_size): try: - return int(population_size) + self.population_size = int(population_size) except (TypeError, ValueError): - raise ValueError( + raise ValidationError( "Project expectations population size must be a number", ) -class Outputs(BaseModel): - highly_sensitive: Optional[Dict[str, str]] - moderately_sensitive: Optional[Dict[str, str]] - minimally_sensitive: Optional[Dict[str, str]] +class Outputs: + def __init__( + self, + highly_sensitive=None, + moderately_sensitive=None, + minimally_sensitive=None, + **kwargs, + ): + self.highly_sensitive = highly_sensitive + self.moderately_sensitive = moderately_sensitive + self.minimally_sensitive = minimally_sensitive + + self.at_least_one_output() + self.validate_output_filenames_are_valid() def __len__(self): - return len(self.dict(exclude_unset=True)) + return len(self.dict()) + + def dict(self): + d = { + k: getattr(self, k) + for k in [ + "highly_sensitive", + "moderately_sensitive", + "minimally_sensitive", + ] + } + return {k: v for k, v in d.items() if v is not None} - @root_validator() - def at_least_one_output(cls, outputs): - if not any(outputs.values()): - raise ValueError( - f"must specify at least one output of: {', '.join(outputs)}" + def at_least_one_output(self): + if not self.dict(): + raise ValidationError( + f"must specify at least one output of: {', '.join(vars(self))}" ) - return outputs - - @root_validator(pre=True) - def validate_output_filenames_are_valid(cls, outputs): - # we use pre=True here so that we only get the outputs specified in the - # input data. With Optional[…] wrapped fields pydantic will set None - # for us and that just makes the logic a little fiddler with no - # benefit. - for privacy_level, output in outputs.items(): + def validate_output_filenames_are_valid(self): + for privacy_level, output in self.dict().items(): for output_id, filename in output.items(): try: assert_valid_glob_pattern(filename, privacy_level) except InvalidPatternError as e: - raise ValueError(f"Output path {filename} is invalid: {e}") + raise ValidationError(f"Output path {filename} is invalid: {e}") - return outputs +class Command: + def __init__(self, raw): + self.raw = raw -class Command(BaseModel): - raw: str # original string + def __eq__(self, other): + if not isinstance(other, Command): # pragma: no cover + return NotImplemented + return self.raw == other.raw - class Config: - # this makes Command hashable, which for some reason due to the - # Action.parse_run_string works, pydantic requires. - frozen = True + def __hash__(self): + return hash(self.raw) @property def args(self): @@ -122,20 +130,20 @@ def version(self): return self.parts[0].split(":")[1] -class Action(BaseModel): - config: Optional[Dict[Any, Any]] = None - run: Command - needs: List[str] = [] - outputs: Outputs - dummy_data_file: Optional[pathlib.Path] +class Action: + def __init__(self, outputs, run, needs=None, config=None, dummy_data_file=None): + self.outputs = Outputs(**outputs) + self.run = self.parse_run_string(run) + self.needs = needs or [] + self.config = config + self.dummy_data_file = dummy_data_file - @validator("run", pre=True) - def parse_run_string(cls, run): + def parse_run_string(self, run): parts = shlex.split(run) name, _, version = parts[0].partition(":") if not version: - raise ValueError( + raise ValidationError( f"{name} must have a version specified (e.g. {name}:0.5.2)", ) @@ -146,10 +154,24 @@ def is_database_action(self): return is_database_action(self.run.parts) -class Pipeline(BaseModel): - version: float - expectations: Expectations - actions: Dict[str, Action] +class Pipeline: + def __init__(self, version=None, actions=None, expectations=None): + self.validate_version_exists(version) + self.version = self.validate_version_value(version) + + self.validate_actions_run(actions) + self.actions = { + action_id: Action(**action_config) + for action_id, action_config in actions.items() + } + self.validate_actions() + self.validate_needs_are_comma_delimited() + self.validate_needs_exist() + self.validate_unique_commands() + self.validate_outputs_per_version() + + expectations = self.validate_expectations_per_version(expectations) + self.expectations = Expectations(**expectations) @property def all_actions(self): @@ -162,117 +184,84 @@ def all_actions(self): """ return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] - @root_validator() - def validate_actions(cls, values): + def validate_actions(self): # TODO: move to Action when we move name onto it validators = { cohortextractor_pat: validate_cohortextractor_outputs, databuilder_pat: validate_databuilder_outputs, } - for action_id, config in values.get("actions", {}).items(): + for action_id, config in self.actions.items(): for cmd, validator_func in validators.items(): if cmd.match(config.run.raw): validator_func(action_id, config) - return values - - @root_validator(pre=True) - def validate_expectations_per_version(cls, values): + def validate_expectations_per_version(self, expectations): """Ensure the expectations key exists for version 3 onwards""" - try: - version = float(values["version"]) - except (KeyError, TypeError, ValueError): - # this is handled in the validate_version_exists and - # validate_version_value validators - return values - - feat = get_feature_flags_for_version(version) + feat = get_feature_flags_for_version(self.version) if not feat.EXPECTATIONS_POPULATION: - # set the default here because pydantic doesn't seem to set it - # otherwise - values["expectations"] = {"population_size": 1000} - return values + return {"population_size": 1000} - if "expectations" not in values: - raise ValueError("Project must include `expectations` section") + if expectations is None: + raise ValidationError("Project must include `expectations` section") - if "population_size" not in values["expectations"]: - raise ValueError( + if "population_size" not in expectations: + raise ValidationError( "Project `expectations` section must include `population_size` section", ) - return values + return expectations - @root_validator() - def validate_outputs_per_version(cls, values): + def validate_outputs_per_version(self): """ Ensure outputs are unique for version 2 onwards We validate this on Pipeline so we can get the version """ - # we're not using pre=True in the validator so we can rely on the - # version and action keys being the correct type but we have to handle - # them not existing - if not (version := values.get("version")): - return values # handle missing version - - if (actions := values.get("actions")) is None: - return values # hand no actions - - feat = get_feature_flags_for_version(version) + feat = get_feature_flags_for_version(self.version) if not feat.UNIQUE_OUTPUT_PATH: - return values + return # find duplicate paths defined in the outputs section seen_files = [] - for config in actions.values(): - for output in config.outputs.dict(exclude_unset=True).values(): + for config in self.actions.values(): + for output in config.outputs.dict().values(): for filename in output.values(): if filename in seen_files: - raise ValueError(f"Output path {filename} is not unique") + raise ValidationError(f"Output path {filename} is not unique") seen_files.append(filename) - return values - - @root_validator(pre=True) - def validate_actions_run(cls, values): + def validate_actions_run(self, actions): # TODO: move to Action when we move name onto it - for action_id, config in values.get("actions", {}).items(): + for action_id, config in actions.items(): if config["run"] == "": # key is present but empty - raise ValueError( + raise ValidationError( f"run must have a value, {action_id} has an empty run key" ) - return values - - @validator("actions") - def validate_unique_commands(cls, actions): + def validate_unique_commands(self): seen = defaultdict(list) - for name, config in actions.items(): + for name, config in self.actions.items(): run = config.run if run in seen: - raise ValueError( + raise ValidationError( f"Action {name} has the same 'run' command as other actions: {' ,'.join(seen[run])}" ) seen[run].append(name) - return actions - - @validator("actions") - def validate_needs_are_comma_delimited(cls, actions): + def validate_needs_are_comma_delimited(self): space_delimited = {} - for name, action in actions.items(): + for name, action in self.actions.items(): # find needs definitions with spaces in them incorrect = [dep for dep in action.needs if " " in dep] if incorrect: space_delimited[name] = incorrect if not space_delimited: - return actions + return def iter_incorrect_needs(space_delimited): for name, needs in space_delimited.items(): @@ -285,18 +274,17 @@ def iter_incorrect_needs(space_delimited): *iter_incorrect_needs(space_delimited), ] - raise ValueError("\n".join(msg)) + raise ValidationError("\n".join(msg)) - @validator("actions") - def validate_needs_exist(cls, actions): + def validate_needs_exist(self): missing = {} - for name, action in actions.items(): - unknown_needs = set(action.needs) - set(actions) + for name, action in self.actions.items(): + unknown_needs = set(action.needs) - set(self.actions) if unknown_needs: missing[name] = unknown_needs if not missing: - return actions + return def iter_missing_needs(missing): for name, needs in missing.items(): @@ -308,31 +296,25 @@ def iter_missing_needs(missing): "One or more actions is referencing unknown actions in its needs list:", *iter_missing_needs(missing), ] - raise ValueError("\n".join(msg)) + raise ValidationError("\n".join(msg)) - @root_validator(pre=True) - def validate_version_exists(cls, values): + def validate_version_exists(self, version): """ Ensure the version key exists. - - This is a re-implementation of pydantic's field validation so we can - get a custom error message. This can be removed when we add a wrapper - around the models to generate more UI friendly error messages. """ - if "version" in values: - return values + if version is not None: + return - raise ValueError( + raise ValidationError( f"Project file must have a `version` attribute specifying which " f"version of the project configuration format it uses (current " f"latest version is {LATEST_VERSION})" ) - @validator("version", pre=True) - def validate_version_value(cls, value): + def validate_version_value(self, value): try: return float(value) except (TypeError, ValueError): - raise ValueError( + raise ValidationError( f"`version` must be a number between 1 and {LATEST_VERSION}" ) diff --git a/pipeline/outputs.py b/pipeline/outputs.py index 3699cc1..c76ec3a 100644 --- a/pipeline/outputs.py +++ b/pipeline/outputs.py @@ -23,5 +23,5 @@ def get_output_dirs(output_spec: Outputs) -> list[PurePosixPath]: def iter_all_outputs(output_spec: Outputs) -> Iterator[str]: - for group in output_spec.dict(exclude_unset=True).values(): + for group in output_spec.dict().values(): yield from group.values() diff --git a/pipeline/validation.py b/pipeline/validation.py index e74ae93..3027a46 100644 --- a/pipeline/validation.py +++ b/pipeline/validation.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING from .constants import LEVEL4_FILE_TYPES -from .exceptions import InvalidPatternError +from .exceptions import InvalidPatternError, ValidationError from .outputs import get_first_output_file, get_output_dirs @@ -72,7 +72,7 @@ def validate_cohortextractor_outputs(action_id: str, action: Action) -> None: # ensure we only have output level defined. num_output_levels = len(action.outputs) if num_output_levels != 1: - raise ValueError( + raise ValidationError( "A `generate_cohort` action must have exactly one output; " f"{action_id} had {num_output_levels}" ) @@ -90,7 +90,7 @@ def validate_cohortextractor_outputs(action_id: str, action: Action) -> None: arg == flag or arg.startswith(f"{flag}=") for arg in action.run.parts ) if not has_output_dir: - raise ValueError( + raise ValidationError( f"generate_cohort command should produce output in only one " f"directory, found {len(output_dirs)}:\n" + "\n".join([f" - {d}/" for d in output_dirs]) @@ -107,11 +107,11 @@ def validate_databuilder_outputs(action_id: str, action: Action) -> None: # TODO: should this be checking output _paths_ instead of levels? num_output_levels = len(action.outputs) if num_output_levels != 1: - raise ValueError( + raise ValidationError( "A `generate-dataset` action must have exactly one output; " f"{action_id} had {num_output_levels}" ) first_output_file = get_first_output_file(action.outputs) if first_output_file not in action.run.raw: - raise ValueError("--output in run command and outputs must match") + raise ValidationError("--output in run command and outputs must match") diff --git a/pyproject.toml b/pyproject.toml index 4a91f9e..3c5ae3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ - "pydantic<2", "ruyaml", ] dynamic = ["version"] @@ -51,9 +50,6 @@ skip_glob = [".direnv", "venv", ".venv"] [tool.mypy] files = "pipeline" exclude = "^pipeline/__main__.py$" -plugins = [ - "pydantic.mypy", -] strict = true warn_redundant_casts = true warn_unused_ignores = true diff --git a/requirements.prod.txt b/requirements.prod.txt index d770e2e..fb5ad3c 100644 --- a/requirements.prod.txt +++ b/requirements.prod.txt @@ -6,12 +6,8 @@ # distro==1.9.0 # via ruyaml -pydantic==1.10.17 - # via opensafely-pipeline (pyproject.toml) ruyaml==0.91.0 # via opensafely-pipeline (pyproject.toml) -typing-extensions==4.12.2 - # via pydantic # The following packages are considered to be unsafe in a requirements file: setuptools==71.1.0 diff --git a/tests/test_main.py b/tests/test_main.py index 7e3f1c3..03aeae2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,7 +1,7 @@ -import pydantic import pytest from pipeline import ProjectValidationError, load_pipeline +from pipeline.exceptions import ValidationError from pipeline.models import Pipeline @@ -67,4 +67,4 @@ def test_load_pipeline_with_project_error_raises_projectvalidationerror(): with pytest.raises(ProjectValidationError, match="Invalid project") as exc: load_pipeline(config) - assert isinstance(exc.value.__cause__, pydantic.ValidationError) + assert isinstance(exc.value.__cause__, ValidationError) diff --git a/tests/test_models.py b/tests/test_models.py index 922fd19..3e4eac3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,7 @@ import pytest -from pydantic import ValidationError from pipeline import load_pipeline +from pipeline.exceptions import ValidationError from pipeline.models import Pipeline @@ -126,7 +126,7 @@ def test_action_extraction_command_with_one_outputs(): config = Pipeline(**data) - outputs = config.actions["generate_cohort"].outputs.dict(exclude_unset=True) + outputs = config.actions["generate_cohort"].outputs.dict() assert len(outputs.values()) == 1 @@ -359,6 +359,10 @@ def test_pipeline_with_missing_or_none_version(): with pytest.raises(ValidationError, match=msg): Pipeline(**data) + with pytest.raises(ValidationError, match=msg): + data["version"] = None + Pipeline(**data) + def test_pipeline_with_non_numeric_version(): data = { @@ -376,10 +380,6 @@ def test_pipeline_with_non_numeric_version(): data["version"] = "test" Pipeline(**data) - with pytest.raises(ValidationError, match=msg): - data["version"] = None - Pipeline(**data) - def test_outputs_files_are_unique(): data = { From 50a5e54562c84864cb928b94fd1477224f9cb328 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 03/20] Move validation out of __init__ methods --- pipeline/main.py | 2 +- pipeline/models.py | 168 ++++++++++++++++++++++++++++-------------- tests/test_models.py | 56 +++++++------- tests/test_outputs.py | 4 +- 4 files changed, 142 insertions(+), 88 deletions(-) diff --git a/pipeline/main.py b/pipeline/main.py index e261be6..a2ce221 100644 --- a/pipeline/main.py +++ b/pipeline/main.py @@ -25,7 +25,7 @@ def load_pipeline(pipeline_config: str | Path, filename: str | None = None) -> P # validate try: - return Pipeline(**parsed_data) + return Pipeline.build(**parsed_data) except ValidationError as exc: raise ProjectValidationError( f"Invalid project: {filename or ''}\n{exc}" diff --git a/pipeline/models.py b/pipeline/models.py index 5effb51..6e58f11 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -47,28 +47,54 @@ def is_database_action(args): class Expectations: def __init__(self, population_size): + self.population_size = population_size + + @classmethod + def build(cls, population_size=None, **kwargs): try: - self.population_size = int(population_size) + population_size = int(population_size) except (TypeError, ValueError): raise ValidationError( "Project expectations population size must be a number", ) + return cls(population_size) class Outputs: - def __init__( - self, + def __init__(self, highly_sensitive, moderately_sensitive, minimally_sensitive): + self.highly_sensitive = highly_sensitive + self.moderately_sensitive = moderately_sensitive + self.minimally_sensitive = minimally_sensitive + + @classmethod + def build( + cls, highly_sensitive=None, moderately_sensitive=None, minimally_sensitive=None, **kwargs, ): - self.highly_sensitive = highly_sensitive - self.moderately_sensitive = moderately_sensitive - self.minimally_sensitive = minimally_sensitive + if ( + highly_sensitive is None + and moderately_sensitive is None + and minimally_sensitive is None + ): + raise ValidationError( + f"must specify at least one output of: {', '.join(['highly_sensitive', 'moderately_sensitive', 'minimally_sensitive'])}" + ) - self.at_least_one_output() - self.validate_output_filenames_are_valid() + cls.validate_output_filenames_are_valid( + "highly_sensitive", + highly_sensitive, + ) + cls.validate_output_filenames_are_valid( + "moderately_sensitive", moderately_sensitive + ) + cls.validate_output_filenames_are_valid( + "minimally_sensitive", minimally_sensitive + ) + + return cls(highly_sensitive, moderately_sensitive, minimally_sensitive) def __len__(self): return len(self.dict()) @@ -84,19 +110,15 @@ def dict(self): } return {k: v for k, v in d.items() if v is not None} - def at_least_one_output(self): - if not self.dict(): - raise ValidationError( - f"must specify at least one output of: {', '.join(vars(self))}" - ) - - def validate_output_filenames_are_valid(self): - for privacy_level, output in self.dict().items(): - for output_id, filename in output.items(): - try: - assert_valid_glob_pattern(filename, privacy_level) - except InvalidPatternError as e: - raise ValidationError(f"Output path {filename} is invalid: {e}") + @classmethod + def validate_output_filenames_are_valid(cls, privacy_level, output): + if output is None: + return + for output_id, filename in output.items(): + try: + assert_valid_glob_pattern(filename, privacy_level) + except InvalidPatternError as e: + raise ValidationError(f"Output path {filename} is invalid: {e}") class Command: @@ -131,14 +153,30 @@ def version(self): class Action: - def __init__(self, outputs, run, needs=None, config=None, dummy_data_file=None): - self.outputs = Outputs(**outputs) - self.run = self.parse_run_string(run) - self.needs = needs or [] + def __init__(self, outputs, run, needs, config, dummy_data_file): + self.outputs = outputs + self.run = run + self.needs = needs self.config = config self.dummy_data_file = dummy_data_file - def parse_run_string(self, run): + @classmethod + def build( + cls, + outputs=None, + run=None, + needs=None, + config=None, + dummy_data_file=None, + **kwargs, + ): + outputs = Outputs.build(**outputs) + run = cls.parse_run_string(run) + needs = needs or [] + return cls(outputs, run, needs, config, dummy_data_file) + + @classmethod + def parse_run_string(cls, run): parts = shlex.split(run) name, _, version = parts[0].partition(":") @@ -155,23 +193,30 @@ def is_database_action(self): class Pipeline: - def __init__(self, version=None, actions=None, expectations=None): - self.validate_version_exists(version) - self.version = self.validate_version_value(version) - - self.validate_actions_run(actions) - self.actions = { - action_id: Action(**action_config) + def __init__(self, version, actions, expectations): + self.version = version + self.actions = actions + self.expectations = expectations + + @classmethod + def build(cls, version=None, actions=None, expectations=None, **kwargs): + cls.validate_version_exists(version) + version = cls.validate_version_value(version) + + cls.validate_actions_run(actions) + actions = { + action_id: Action.build(**action_config) for action_id, action_config in actions.items() } - self.validate_actions() - self.validate_needs_are_comma_delimited() - self.validate_needs_exist() - self.validate_unique_commands() - self.validate_outputs_per_version() + cls.validate_actions(actions) + cls.validate_needs_are_comma_delimited(actions) + cls.validate_needs_exist(actions) + cls.validate_unique_commands(actions) + cls.validate_outputs_per_version(version, actions) - expectations = self.validate_expectations_per_version(expectations) - self.expectations = Expectations(**expectations) + expectations = cls.validate_expectations_per_version(version, expectations) + expectations = Expectations.build(**expectations) + return cls(version, actions, expectations) @property def all_actions(self): @@ -184,20 +229,22 @@ def all_actions(self): """ return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] - def validate_actions(self): + @classmethod + def validate_actions(cls, actions): # TODO: move to Action when we move name onto it validators = { cohortextractor_pat: validate_cohortextractor_outputs, databuilder_pat: validate_databuilder_outputs, } - for action_id, config in self.actions.items(): + for action_id, config in actions.items(): for cmd, validator_func in validators.items(): if cmd.match(config.run.raw): validator_func(action_id, config) - def validate_expectations_per_version(self, expectations): + @classmethod + def validate_expectations_per_version(cls, version, expectations): """Ensure the expectations key exists for version 3 onwards""" - feat = get_feature_flags_for_version(self.version) + feat = get_feature_flags_for_version(version) if not feat.EXPECTATIONS_POPULATION: return {"population_size": 1000} @@ -212,20 +259,21 @@ def validate_expectations_per_version(self, expectations): return expectations - def validate_outputs_per_version(self): + @classmethod + def validate_outputs_per_version(cls, version, actions): """ Ensure outputs are unique for version 2 onwards We validate this on Pipeline so we can get the version """ - feat = get_feature_flags_for_version(self.version) + feat = get_feature_flags_for_version(version) if not feat.UNIQUE_OUTPUT_PATH: return # find duplicate paths defined in the outputs section seen_files = [] - for config in self.actions.values(): + for config in actions.values(): for output in config.outputs.dict().values(): for filename in output.values(): if filename in seen_files: @@ -233,7 +281,8 @@ def validate_outputs_per_version(self): seen_files.append(filename) - def validate_actions_run(self, actions): + @classmethod + def validate_actions_run(cls, actions): # TODO: move to Action when we move name onto it for action_id, config in actions.items(): if config["run"] == "": @@ -242,9 +291,10 @@ def validate_actions_run(self, actions): f"run must have a value, {action_id} has an empty run key" ) - def validate_unique_commands(self): + @classmethod + def validate_unique_commands(cls, actions): seen = defaultdict(list) - for name, config in self.actions.items(): + for name, config in actions.items(): run = config.run if run in seen: raise ValidationError( @@ -252,9 +302,10 @@ def validate_unique_commands(self): ) seen[run].append(name) - def validate_needs_are_comma_delimited(self): + @classmethod + def validate_needs_are_comma_delimited(cls, actions): space_delimited = {} - for name, action in self.actions.items(): + for name, action in actions.items(): # find needs definitions with spaces in them incorrect = [dep for dep in action.needs if " " in dep] if incorrect: @@ -276,10 +327,11 @@ def iter_incorrect_needs(space_delimited): raise ValidationError("\n".join(msg)) - def validate_needs_exist(self): + @classmethod + def validate_needs_exist(cls, actions): missing = {} - for name, action in self.actions.items(): - unknown_needs = set(action.needs) - set(self.actions) + for name, action in actions.items(): + unknown_needs = set(action.needs) - set(actions) if unknown_needs: missing[name] = unknown_needs @@ -298,7 +350,8 @@ def iter_missing_needs(missing): ] raise ValidationError("\n".join(msg)) - def validate_version_exists(self, version): + @classmethod + def validate_version_exists(cls, version): """ Ensure the version key exists. """ @@ -311,7 +364,8 @@ def validate_version_exists(self, version): f"latest version is {LATEST_VERSION})" ) - def validate_version_value(self, value): + @classmethod + def validate_version_value(cls, value): try: return float(value) except (TypeError, ValueError): diff --git a/tests/test_models.py b/tests/test_models.py index 3e4eac3..686525c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -19,7 +19,7 @@ def test_success(): }, } - Pipeline(**data) + Pipeline.build(**data) def test_action_has_a_version(): @@ -37,7 +37,7 @@ def test_action_has_a_version(): msg = "test must have a version specified" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_action_cohortextractor_multiple_outputs_with_output_flag(): @@ -56,7 +56,7 @@ def test_action_cohortextractor_multiple_outputs_with_output_flag(): }, } - run_command = Pipeline(**data).actions["generate_cohort"].run.raw + run_command = Pipeline.build(**data).actions["generate_cohort"].run.raw assert run_command == "cohortextractor:latest generate_cohort --output-dir=output" @@ -81,7 +81,7 @@ def test_action_cohortextractor_multiple_ouputs_without_output_flag(): "generate_cohort command should produce output in only one directory, found 2:" ) with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize( @@ -108,7 +108,7 @@ def test_action_extraction_command_with_multiple_outputs(image, command): msg = f"A `{command}` action must have exactly one output" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_action_extraction_command_with_one_outputs(): @@ -124,7 +124,7 @@ def test_action_extraction_command_with_one_outputs(): }, } - config = Pipeline(**data) + config = Pipeline.build(**data) outputs = config.actions["generate_cohort"].outputs.dict() assert len(outputs.values()) == 1 @@ -141,7 +141,7 @@ def test_command_properties(): }, } - action = Pipeline(**data).actions["generate_cohort"] + action = Pipeline.build(**data).actions["generate_cohort"] assert action.run.args == "generate_cohort another_arg" assert action.run.name == "cohortextractor" assert action.run.parts == [ @@ -164,7 +164,7 @@ def test_expectations_before_v3_has_a_default_set(): }, } - config = Pipeline(**data) + config = Pipeline.build(**data) assert config.expectations.population_size == 1000 @@ -183,7 +183,7 @@ def test_expectations_exists(): msg = "Project must include `expectations` section" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_expectations_population_size_exists(): @@ -200,7 +200,7 @@ def test_expectations_population_size_exists(): msg = "Project `expectations` section must include `population_size` section" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_expectations_population_size_is_a_number(): @@ -217,7 +217,7 @@ def test_expectations_population_size_is_a_number(): msg = "Project expectations population size must be a number" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_all_actions(test_file): @@ -252,7 +252,7 @@ def test_pipeline_needs_success(): }, } - config = Pipeline(**data) + config = Pipeline.build(**data) assert config.actions["do_analysis"].needs == ["generate_cohort"] @@ -279,7 +279,7 @@ def test_pipeline_needs_with_non_comma_delimited_actions(): msg = "`needs` actions should be separated with commas. The following actions need fixing:" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_needs_with_unknown_action(): @@ -298,7 +298,7 @@ def test_pipeline_needs_with_unknown_action(): match = "One or more actions is referencing unknown actions in its needs list" with pytest.raises(ValidationError, match=match): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_duplicated_action_run_commands(): @@ -322,7 +322,7 @@ def test_pipeline_with_duplicated_action_run_commands(): match = "Action action2 has the same 'run' command as other actions: action1" with pytest.raises(ValidationError, match=match): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_empty_run_command(): @@ -340,7 +340,7 @@ def test_pipeline_with_empty_run_command(): match = "run must have a value, action1 has an empty run key" with pytest.raises(ValidationError, match=match): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_missing_or_none_version(): @@ -357,11 +357,11 @@ def test_pipeline_with_missing_or_none_version(): msg = "Project file must have a `version` attribute" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) with pytest.raises(ValidationError, match=msg): data["version"] = None - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_non_numeric_version(): @@ -378,7 +378,7 @@ def test_pipeline_with_non_numeric_version(): with pytest.raises(ValidationError, match=msg): data["version"] = "test" - Pipeline(**data) + Pipeline.build(**data) def test_outputs_files_are_unique(): @@ -399,7 +399,7 @@ def test_outputs_files_are_unique(): msg = "Output path output/input.csv is not unique" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_outputs_duplicate_files_in_v1(): @@ -418,7 +418,7 @@ def test_outputs_duplicate_files_in_v1(): }, } - generate_cohort = Pipeline(**data).actions["generate_cohort"] + generate_cohort = Pipeline.build(**data).actions["generate_cohort"] cohort = generate_cohort.outputs.highly_sensitive["cohort"] test = generate_cohort.outputs.highly_sensitive["test"] @@ -431,7 +431,7 @@ def test_outputs_with_unknown_privacy_level(): with pytest.raises(ValidationError, match=msg): # no outputs - Pipeline( + Pipeline.build( **{ "version": 1, "actions": { @@ -444,7 +444,7 @@ def test_outputs_with_unknown_privacy_level(): ) with pytest.raises(ValidationError, match=msg): - Pipeline( + Pipeline.build( **{ "version": 1, "actions": { @@ -470,7 +470,7 @@ def test_outputs_with_invalid_pattern(): msg = "Output path test/foo is invalid:" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize("image,tag", [("databuilder", "latest"), ("ehrql", "v0")]) @@ -485,7 +485,7 @@ def test_pipeline_databuilder_specifies_same_output(image, tag): }, } - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize("image,tag", [("databuilder", "latest"), ("ehrql", "v0")]) @@ -502,7 +502,7 @@ def test_pipeline_databuilder_specifies_different_output(image, tag): msg = "--output in run command and outputs must match" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_databuilder_recognizes_old_action_spelling(): @@ -519,7 +519,7 @@ def test_pipeline_databuilder_recognizes_old_action_spelling(): } with pytest.raises(ValidationError): - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize( @@ -575,5 +575,5 @@ def test_action_is_database_action(name, run, is_database_action): }, } - action = Pipeline(**data).actions[name] + action = Pipeline.build(**data).actions[name] assert action.is_database_action == is_database_action diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 5a02cbb..b727596 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -5,7 +5,7 @@ def test_get_output_dirs_with_duplicates(): - outputs = Outputs( + outputs = Outputs.build( highly_sensitive={ "a": "output/1a.csv", "b": "output/2a.csv", @@ -18,7 +18,7 @@ def test_get_output_dirs_with_duplicates(): def test_get_output_dirs_without_duplicates(): - outputs = Outputs( + outputs = Outputs.build( highly_sensitive={ "a": "1a/output.csv", "b": "2a/output.csv", From 18b14831d18408ece9b971400a93c4ec5f2409b8 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 04/20] Use dataclasses --- pipeline/models.py | 47 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 6e58f11..d4e7572 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -1,6 +1,9 @@ +import pathlib import re import shlex from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional from .constants import RUN_ALL_COMMAND from .exceptions import InvalidPatternError, ValidationError @@ -45,9 +48,9 @@ def is_database_action(args): return args[1] in db_commands +@dataclass(frozen=True) class Expectations: - def __init__(self, population_size): - self.population_size = population_size + population_size: int @classmethod def build(cls, population_size=None, **kwargs): @@ -60,11 +63,11 @@ def build(cls, population_size=None, **kwargs): return cls(population_size) +@dataclass(frozen=True) class Outputs: - def __init__(self, highly_sensitive, moderately_sensitive, minimally_sensitive): - self.highly_sensitive = highly_sensitive - self.moderately_sensitive = moderately_sensitive - self.minimally_sensitive = minimally_sensitive + highly_sensitive: Optional[Dict[str, str]] + moderately_sensitive: Optional[Dict[str, str]] + minimally_sensitive: Optional[Dict[str, str]] @classmethod def build( @@ -121,17 +124,9 @@ def validate_output_filenames_are_valid(cls, privacy_level, output): raise ValidationError(f"Output path {filename} is invalid: {e}") +@dataclass(frozen=True) class Command: - def __init__(self, raw): - self.raw = raw - - def __eq__(self, other): - if not isinstance(other, Command): # pragma: no cover - return NotImplemented - return self.raw == other.raw - - def __hash__(self): - return hash(self.raw) + raw: str @property def args(self): @@ -152,13 +147,13 @@ def version(self): return self.parts[0].split(":")[1] +@dataclass(frozen=True) class Action: - def __init__(self, outputs, run, needs, config, dummy_data_file): - self.outputs = outputs - self.run = run - self.needs = needs - self.config = config - self.dummy_data_file = dummy_data_file + outputs: Outputs + run: Command + needs: List[str] + config: Optional[Dict[Any, Any]] + dummy_data_file: Optional[pathlib.Path] @classmethod def build( @@ -192,11 +187,11 @@ def is_database_action(self): return is_database_action(self.run.parts) +@dataclass(frozen=True) class Pipeline: - def __init__(self, version, actions, expectations): - self.version = version - self.actions = actions - self.expectations = expectations + version: float + actions: Dict[str, Action] + expectations: Expectations @classmethod def build(cls, version=None, actions=None, expectations=None, **kwargs): From 0feaa17852ace344ab97762688191d519839fa10 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 05/20] Reinstate typing --- pipeline/legacy.py | 2 +- pipeline/models.py | 114 +++++++++++++++++++++++++++------------------ 2 files changed, 69 insertions(+), 47 deletions(-) diff --git a/pipeline/legacy.py b/pipeline/legacy.py index 7902164..ff795d1 100644 --- a/pipeline/legacy.py +++ b/pipeline/legacy.py @@ -5,7 +5,7 @@ def get_all_output_patterns_from_project_file(project_file: str) -> list[str]: config = load_pipeline(project_file) - all_patterns = set() + all_patterns: set[str] = set() for action in config.actions.values(): for patterns in action.outputs.dict().values(): all_patterns.update(patterns.values()) diff --git a/pipeline/models.py b/pipeline/models.py index d4e7572..ca647b4 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import pathlib import re import shlex from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable from .constants import RUN_ALL_COMMAND from .exceptions import InvalidPatternError, ValidationError @@ -27,7 +29,7 @@ } -def is_database_action(args): +def is_database_action(args: list[str]) -> bool: """ By default actions do not have database access, but certain trusted actions require it """ @@ -53,7 +55,11 @@ class Expectations: population_size: int @classmethod - def build(cls, population_size=None, **kwargs): + def build( + cls, + population_size: Any = None, + **kwargs: Any, + ) -> Expectations: try: population_size = int(population_size) except (TypeError, ValueError): @@ -65,18 +71,18 @@ def build(cls, population_size=None, **kwargs): @dataclass(frozen=True) class Outputs: - highly_sensitive: Optional[Dict[str, str]] - moderately_sensitive: Optional[Dict[str, str]] - minimally_sensitive: Optional[Dict[str, str]] + highly_sensitive: dict[str, str] | None + moderately_sensitive: dict[str, str] | None + minimally_sensitive: dict[str, str] | None @classmethod def build( cls, - highly_sensitive=None, - moderately_sensitive=None, - minimally_sensitive=None, - **kwargs, - ): + highly_sensitive: Any = None, + moderately_sensitive: Any = None, + minimally_sensitive: Any = None, + **kwargs: Any, + ) -> Outputs: if ( highly_sensitive is None and moderately_sensitive is None @@ -99,10 +105,10 @@ def build( return cls(highly_sensitive, moderately_sensitive, minimally_sensitive) - def __len__(self): + def __len__(self) -> int: return len(self.dict()) - def dict(self): + def dict(self) -> dict[str, dict[str, str]]: d = { k: getattr(self, k) for k in [ @@ -114,7 +120,9 @@ def dict(self): return {k: v for k, v in d.items() if v is not None} @classmethod - def validate_output_filenames_are_valid(cls, privacy_level, output): + def validate_output_filenames_are_valid( + cls, privacy_level: str, output: Any + ) -> None: if output is None: return for output_id, filename in output.items(): @@ -129,20 +137,20 @@ class Command: raw: str @property - def args(self): + def args(self) -> str: return " ".join(self.parts[1:]) @property - def name(self): + def name(self) -> str: # parts[0] with version split off return self.parts[0].split(":")[0] @property - def parts(self): + def parts(self) -> list[str]: return shlex.split(self.raw) @property - def version(self): + def version(self) -> str: # parts[0] with name split off return self.parts[0].split(":")[1] @@ -151,27 +159,27 @@ def version(self): class Action: outputs: Outputs run: Command - needs: List[str] - config: Optional[Dict[Any, Any]] - dummy_data_file: Optional[pathlib.Path] + needs: list[str] + config: dict[Any, Any] | None + dummy_data_file: pathlib.Path | None @classmethod def build( cls, - outputs=None, - run=None, - needs=None, - config=None, - dummy_data_file=None, - **kwargs, - ): + outputs: Any = None, + run: Any = None, + needs: Any = None, + config: Any = None, + dummy_data_file: Any = None, + **kwargs: Any, + ) -> Action: outputs = Outputs.build(**outputs) run = cls.parse_run_string(run) needs = needs or [] return cls(outputs, run, needs, config, dummy_data_file) @classmethod - def parse_run_string(cls, run): + def parse_run_string(cls, run: Any) -> Command: parts = shlex.split(run) name, _, version = parts[0].partition(":") @@ -183,18 +191,28 @@ def parse_run_string(cls, run): return Command(raw=run) @property - def is_database_action(self): + def is_database_action(self) -> bool: return is_database_action(self.run.parts) +Version = float +Actions = Dict[str, Action] + + @dataclass(frozen=True) class Pipeline: - version: float - actions: Dict[str, Action] + version: Version + actions: Actions expectations: Expectations @classmethod - def build(cls, version=None, actions=None, expectations=None, **kwargs): + def build( + cls, + version: Any = None, + actions: Any = None, + expectations: Any = None, + **kwargs: Any, + ) -> Pipeline: cls.validate_version_exists(version) version = cls.validate_version_value(version) @@ -214,7 +232,7 @@ def build(cls, version=None, actions=None, expectations=None, **kwargs): return cls(version, actions, expectations) @property - def all_actions(self): + def all_actions(self) -> list[str]: """ Get all actions for this Pipeline instance @@ -225,7 +243,7 @@ def all_actions(self): return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] @classmethod - def validate_actions(cls, actions): + def validate_actions(cls, actions: Actions) -> None: # TODO: move to Action when we move name onto it validators = { cohortextractor_pat: validate_cohortextractor_outputs, @@ -237,7 +255,9 @@ def validate_actions(cls, actions): validator_func(action_id, config) @classmethod - def validate_expectations_per_version(cls, version, expectations): + def validate_expectations_per_version( + cls, version: Version, expectations: Any + ) -> Any: """Ensure the expectations key exists for version 3 onwards""" feat = get_feature_flags_for_version(version) @@ -255,7 +275,7 @@ def validate_expectations_per_version(cls, version, expectations): return expectations @classmethod - def validate_outputs_per_version(cls, version, actions): + def validate_outputs_per_version(cls, version: Version, actions: Actions) -> None: """ Ensure outputs are unique for version 2 onwards @@ -277,7 +297,7 @@ def validate_outputs_per_version(cls, version, actions): seen_files.append(filename) @classmethod - def validate_actions_run(cls, actions): + def validate_actions_run(cls, actions: Any) -> None: # TODO: move to Action when we move name onto it for action_id, config in actions.items(): if config["run"] == "": @@ -287,8 +307,8 @@ def validate_actions_run(cls, actions): ) @classmethod - def validate_unique_commands(cls, actions): - seen = defaultdict(list) + def validate_unique_commands(cls, actions: Actions) -> None: + seen: dict[Command, list[str]] = defaultdict(list) for name, config in actions.items(): run = config.run if run in seen: @@ -298,7 +318,7 @@ def validate_unique_commands(cls, actions): seen[run].append(name) @classmethod - def validate_needs_are_comma_delimited(cls, actions): + def validate_needs_are_comma_delimited(cls, actions: Actions) -> None: space_delimited = {} for name, action in actions.items(): # find needs definitions with spaces in them @@ -309,7 +329,9 @@ def validate_needs_are_comma_delimited(cls, actions): if not space_delimited: return - def iter_incorrect_needs(space_delimited): + def iter_incorrect_needs( + space_delimited: dict[str, list[str]], + ) -> Iterable[str]: for name, needs in space_delimited.items(): yield f"Action: {name}" for need in needs: @@ -323,7 +345,7 @@ def iter_incorrect_needs(space_delimited): raise ValidationError("\n".join(msg)) @classmethod - def validate_needs_exist(cls, actions): + def validate_needs_exist(cls, actions: Actions) -> None: missing = {} for name, action in actions.items(): unknown_needs = set(action.needs) - set(actions) @@ -333,7 +355,7 @@ def validate_needs_exist(cls, actions): if not missing: return - def iter_missing_needs(missing): + def iter_missing_needs(missing: dict[str, set[str]]) -> Iterable[str]: for name, needs in missing.items(): yield f"Action: {name}" for need in needs: @@ -346,7 +368,7 @@ def iter_missing_needs(missing): raise ValidationError("\n".join(msg)) @classmethod - def validate_version_exists(cls, version): + def validate_version_exists(cls, version: Any) -> None: """ Ensure the version key exists. """ @@ -360,7 +382,7 @@ def validate_version_exists(cls, version): ) @classmethod - def validate_version_value(cls, value): + def validate_version_value(cls, value: Any) -> float: try: return float(value) except (TypeError, ValueError): From 3cf071de9c9415ec3020a9685eaf0674a5e99106 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 06/20] Move version validation inline --- pipeline/models.py | 38 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index ca647b4..987d2f4 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -213,8 +213,19 @@ def build( expectations: Any = None, **kwargs: Any, ) -> Pipeline: - cls.validate_version_exists(version) - version = cls.validate_version_value(version) + if version is None: + raise ValidationError( + f"Project file must have a `version` attribute specifying which " + f"version of the project configuration format it uses (current " + f"latest version is {LATEST_VERSION})" + ) + + try: + version = float(version) + except (TypeError, ValueError): + raise ValidationError( + f"`version` must be a number between 1 and {LATEST_VERSION}" + ) cls.validate_actions_run(actions) actions = { @@ -366,26 +377,3 @@ def iter_missing_needs(missing: dict[str, set[str]]) -> Iterable[str]: *iter_missing_needs(missing), ] raise ValidationError("\n".join(msg)) - - @classmethod - def validate_version_exists(cls, version: Any) -> None: - """ - Ensure the version key exists. - """ - if version is not None: - return - - raise ValidationError( - f"Project file must have a `version` attribute specifying which " - f"version of the project configuration format it uses (current " - f"latest version is {LATEST_VERSION})" - ) - - @classmethod - def validate_version_value(cls, value: Any) -> float: - try: - return float(value) - except (TypeError, ValueError): - raise ValidationError( - f"`version` must be a number between 1 and {LATEST_VERSION}" - ) From dd3b9d56c498bc5cc084a71e0b652f0e0095a4a9 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 07/20] Add expectations validation inline --- pipeline/models.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 987d2f4..5228eb3 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -238,8 +238,19 @@ def build( cls.validate_unique_commands(actions) cls.validate_outputs_per_version(version, actions) - expectations = cls.validate_expectations_per_version(version, expectations) + feat = get_feature_flags_for_version(version) + if feat.EXPECTATIONS_POPULATION: + if expectations is None: + raise ValidationError("Project must include `expectations` section") + else: + expectations = {"population_size": 1000} + + if "population_size" not in expectations: + raise ValidationError( + "Project `expectations` section must include `population_size` section", + ) expectations = Expectations.build(**expectations) + return cls(version, actions, expectations) @property @@ -265,26 +276,6 @@ def validate_actions(cls, actions: Actions) -> None: if cmd.match(config.run.raw): validator_func(action_id, config) - @classmethod - def validate_expectations_per_version( - cls, version: Version, expectations: Any - ) -> Any: - """Ensure the expectations key exists for version 3 onwards""" - feat = get_feature_flags_for_version(version) - - if not feat.EXPECTATIONS_POPULATION: - return {"population_size": 1000} - - if expectations is None: - raise ValidationError("Project must include `expectations` section") - - if "population_size" not in expectations: - raise ValidationError( - "Project `expectations` section must include `population_size` section", - ) - - return expectations - @classmethod def validate_outputs_per_version(cls, version: Version, actions: Actions) -> None: """ From 4255e77f6cb9e9920b70c0259e7c37f5264d8375 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 08/20] Add Action.action_id attribute --- pipeline/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 5228eb3..09a7b5d 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -157,6 +157,7 @@ def version(self) -> str: @dataclass(frozen=True) class Action: + action_id: str outputs: Outputs run: Command needs: list[str] @@ -166,6 +167,7 @@ class Action: @classmethod def build( cls, + action_id: str, outputs: Any = None, run: Any = None, needs: Any = None, @@ -176,7 +178,7 @@ def build( outputs = Outputs.build(**outputs) run = cls.parse_run_string(run) needs = needs or [] - return cls(outputs, run, needs, config, dummy_data_file) + return cls(action_id, outputs, run, needs, config, dummy_data_file) @classmethod def parse_run_string(cls, run: Any) -> Command: @@ -229,7 +231,7 @@ def build( cls.validate_actions_run(actions) actions = { - action_id: Action.build(**action_config) + action_id: Action.build(action_id, **action_config) for action_id, action_config in actions.items() } cls.validate_actions(actions) From 2702c1bbc716d87d771064ff9e8e34c7c580ce02 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 09/20] Remove unnecessary indirection --- pipeline/models.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 09a7b5d..aab9ca0 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -269,14 +269,11 @@ def all_actions(self) -> list[str]: @classmethod def validate_actions(cls, actions: Actions) -> None: # TODO: move to Action when we move name onto it - validators = { - cohortextractor_pat: validate_cohortextractor_outputs, - databuilder_pat: validate_databuilder_outputs, - } for action_id, config in actions.items(): - for cmd, validator_func in validators.items(): - if cmd.match(config.run.raw): - validator_func(action_id, config) + if cohortextractor_pat.match(config.run.raw): + validate_cohortextractor_outputs(action_id, config) + if databuilder_pat.match(config.run.raw): + validate_databuilder_outputs(action_id, config) @classmethod def validate_outputs_per_version(cls, version: Version, actions: Actions) -> None: From 54b9707ee6358b4ba30e288966c846853bcfe964 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 10/20] Move regexes nearer where they're used --- pipeline/models.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index aab9ca0..3f813c3 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -17,9 +17,6 @@ ) -cohortextractor_pat = re.compile(r"cohortextractor:\S+ generate_cohort") -databuilder_pat = re.compile(r"databuilder|ehrql:\S+ generate[-_]dataset") - # orderd by most common, going forwards DB_COMMANDS = { "ehrql": ("generate-dataset", "generate-measures"), @@ -270,9 +267,9 @@ def all_actions(self) -> list[str]: def validate_actions(cls, actions: Actions) -> None: # TODO: move to Action when we move name onto it for action_id, config in actions.items(): - if cohortextractor_pat.match(config.run.raw): + if re.match(r"cohortextractor:\S+ generate_cohort", config.run.raw): validate_cohortextractor_outputs(action_id, config) - if databuilder_pat.match(config.run.raw): + if re.match(r"databuilder|ehrql:\S+ generate[-_]dataset", config.run.raw): validate_databuilder_outputs(action_id, config) @classmethod From b3aaa3f09b5fb07db201498752a005cdb03df29a Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 11/20] Move some validation to Action.build --- pipeline/models.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 3f813c3..4a85eb1 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -173,12 +173,22 @@ def build( **kwargs: Any, ) -> Action: outputs = Outputs.build(**outputs) - run = cls.parse_run_string(run) + run = cls.parse_run_string(action_id, run) needs = needs or [] - return cls(action_id, outputs, run, needs, config, dummy_data_file) + action = cls(action_id, outputs, run, needs, config, dummy_data_file) + if re.match(r"cohortextractor:\S+ generate_cohort", run.raw): + validate_cohortextractor_outputs(action_id, action) + if re.match(r"databuilder|ehrql:\S+ generate[-_]dataset", run.raw): + validate_databuilder_outputs(action_id, action) + return action @classmethod - def parse_run_string(cls, run: Any) -> Command: + def parse_run_string(cls, action_id: str, run: Any) -> Command: + if run == "": + raise ValidationError( + f"run must have a value, {action_id} has an empty run key" + ) + parts = shlex.split(run) name, _, version = parts[0].partition(":") @@ -226,12 +236,10 @@ def build( f"`version` must be a number between 1 and {LATEST_VERSION}" ) - cls.validate_actions_run(actions) actions = { action_id: Action.build(action_id, **action_config) for action_id, action_config in actions.items() } - cls.validate_actions(actions) cls.validate_needs_are_comma_delimited(actions) cls.validate_needs_exist(actions) cls.validate_unique_commands(actions) @@ -263,15 +271,6 @@ def all_actions(self) -> list[str]: """ return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] - @classmethod - def validate_actions(cls, actions: Actions) -> None: - # TODO: move to Action when we move name onto it - for action_id, config in actions.items(): - if re.match(r"cohortextractor:\S+ generate_cohort", config.run.raw): - validate_cohortextractor_outputs(action_id, config) - if re.match(r"databuilder|ehrql:\S+ generate[-_]dataset", config.run.raw): - validate_databuilder_outputs(action_id, config) - @classmethod def validate_outputs_per_version(cls, version: Version, actions: Actions) -> None: """ @@ -294,16 +293,6 @@ def validate_outputs_per_version(cls, version: Version, actions: Actions) -> Non seen_files.append(filename) - @classmethod - def validate_actions_run(cls, actions: Any) -> None: - # TODO: move to Action when we move name onto it - for action_id, config in actions.items(): - if config["run"] == "": - # key is present but empty - raise ValidationError( - f"run must have a value, {action_id} has an empty run key" - ) - @classmethod def validate_unique_commands(cls, actions: Actions) -> None: seen: dict[Command, list[str]] = defaultdict(list) From b58f4f42021c8ce07da75085ed8b70cb5c760238 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 12/20] Move more validation to Action.build We lose the ability to report more than one error at once, but the code is simpler. --- pipeline/models.py | 35 +++++++---------------------------- tests/test_models.py | 2 +- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 4a85eb1..8c8cfc7 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -175,11 +175,18 @@ def build( outputs = Outputs.build(**outputs) run = cls.parse_run_string(action_id, run) needs = needs or [] + for n in needs: + if " " in n: + raise ValidationError( + f"`needs` actions should be separated with commas, but {action_id} needs `{n}`" + ) action = cls(action_id, outputs, run, needs, config, dummy_data_file) + if re.match(r"cohortextractor:\S+ generate_cohort", run.raw): validate_cohortextractor_outputs(action_id, action) if re.match(r"databuilder|ehrql:\S+ generate[-_]dataset", run.raw): validate_databuilder_outputs(action_id, action) + return action @classmethod @@ -240,7 +247,6 @@ def build( action_id: Action.build(action_id, **action_config) for action_id, action_config in actions.items() } - cls.validate_needs_are_comma_delimited(actions) cls.validate_needs_exist(actions) cls.validate_unique_commands(actions) cls.validate_outputs_per_version(version, actions) @@ -304,33 +310,6 @@ def validate_unique_commands(cls, actions: Actions) -> None: ) seen[run].append(name) - @classmethod - def validate_needs_are_comma_delimited(cls, actions: Actions) -> None: - space_delimited = {} - for name, action in actions.items(): - # find needs definitions with spaces in them - incorrect = [dep for dep in action.needs if " " in dep] - if incorrect: - space_delimited[name] = incorrect - - if not space_delimited: - return - - def iter_incorrect_needs( - space_delimited: dict[str, list[str]], - ) -> Iterable[str]: - for name, needs in space_delimited.items(): - yield f"Action: {name}" - for need in needs: - yield f" - {need}" - - msg = [ - "`needs` actions should be separated with commas. The following actions need fixing:", - *iter_incorrect_needs(space_delimited), - ] - - raise ValidationError("\n".join(msg)) - @classmethod def validate_needs_exist(cls, actions: Actions) -> None: missing = {} diff --git a/tests/test_models.py b/tests/test_models.py index 686525c..8984066 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -277,7 +277,7 @@ def test_pipeline_needs_with_non_comma_delimited_actions(): }, } - msg = "`needs` actions should be separated with commas. The following actions need fixing:" + msg = "`needs` actions should be separated with commas, but do_further_analysis needs `generate_cohort do_analysis`" with pytest.raises(ValidationError, match=msg): Pipeline.build(**data) From 46e3d2aefe25a44fadfe6e4cd85067abe0d7d4ef Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 13/20] More more validation inline Again, we lose the ability to report more than one error at once, but the code is simpler. --- pipeline/models.py | 33 ++++++++------------------------- tests/test_models.py | 2 +- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 8c8cfc7..eb55200 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -5,7 +5,7 @@ import shlex from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, Iterable +from typing import Any, Dict from .constants import RUN_ALL_COMMAND from .exceptions import InvalidPatternError, ValidationError @@ -247,10 +247,16 @@ def build( action_id: Action.build(action_id, **action_config) for action_id, action_config in actions.items() } - cls.validate_needs_exist(actions) cls.validate_unique_commands(actions) cls.validate_outputs_per_version(version, actions) + for a in actions.values(): + for n in a.needs: + if n not in actions: + raise ValidationError( + f"Action `{a.action_id}` references an unknown action in its `needs` list: {n}" + ) + feat = get_feature_flags_for_version(version) if feat.EXPECTATIONS_POPULATION: if expectations is None: @@ -309,26 +315,3 @@ def validate_unique_commands(cls, actions: Actions) -> None: f"Action {name} has the same 'run' command as other actions: {' ,'.join(seen[run])}" ) seen[run].append(name) - - @classmethod - def validate_needs_exist(cls, actions: Actions) -> None: - missing = {} - for name, action in actions.items(): - unknown_needs = set(action.needs) - set(actions) - if unknown_needs: - missing[name] = unknown_needs - - if not missing: - return - - def iter_missing_needs(missing: dict[str, set[str]]) -> Iterable[str]: - for name, needs in missing.items(): - yield f"Action: {name}" - for need in needs: - yield f" - {need}" - - msg = [ - "One or more actions is referencing unknown actions in its needs list:", - *iter_missing_needs(missing), - ] - raise ValidationError("\n".join(msg)) diff --git a/tests/test_models.py b/tests/test_models.py index 8984066..55e2465 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -296,7 +296,7 @@ def test_pipeline_needs_with_unknown_action(): }, } - match = "One or more actions is referencing unknown actions in its needs list" + match = "Action `action1` references an unknown action in its `needs` list: action2" with pytest.raises(ValidationError, match=match): Pipeline.build(**data) From c07a91f6c9b54c40e65ed6ca235d448b4ea07ebd Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 14/20] Move more validation inline --- pipeline/models.py | 57 ++++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index eb55200..ef697f6 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -247,8 +247,28 @@ def build( action_id: Action.build(action_id, **action_config) for action_id, action_config in actions.items() } - cls.validate_unique_commands(actions) - cls.validate_outputs_per_version(version, actions) + + seen: dict[Command, list[str]] = defaultdict(list) + for name, config in actions.items(): + run = config.run + if run in seen: + raise ValidationError( + f"Action {name} has the same 'run' command as other actions: {' ,'.join(seen[run])}" + ) + seen[run].append(name) + + if get_feature_flags_for_version(version).UNIQUE_OUTPUT_PATH: + # find duplicate paths defined in the outputs section + seen_files = [] + for config in actions.values(): + for output in config.outputs.dict().values(): + for filename in output.values(): + if filename in seen_files: + raise ValidationError( + f"Output path {filename} is not unique" + ) + + seen_files.append(filename) for a in actions.values(): for n in a.needs: @@ -282,36 +302,3 @@ def all_actions(self) -> list[str]: than set operators as previously so we preserve the original order. """ return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] - - @classmethod - def validate_outputs_per_version(cls, version: Version, actions: Actions) -> None: - """ - Ensure outputs are unique for version 2 onwards - - We validate this on Pipeline so we can get the version - """ - - feat = get_feature_flags_for_version(version) - if not feat.UNIQUE_OUTPUT_PATH: - return - - # find duplicate paths defined in the outputs section - seen_files = [] - for config in actions.values(): - for output in config.outputs.dict().values(): - for filename in output.values(): - if filename in seen_files: - raise ValidationError(f"Output path {filename} is not unique") - - seen_files.append(filename) - - @classmethod - def validate_unique_commands(cls, actions: Actions) -> None: - seen: dict[Command, list[str]] = defaultdict(list) - for name, config in actions.items(): - run = config.run - if run in seen: - raise ValidationError( - f"Action {name} has the same 'run' command as other actions: {' ,'.join(seen[run])}" - ) - seen[run].append(name) From 70e7d4fed2421ae5e61124771644189794a1a9c7 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 15/20] Rename function (It doesn't raise AssertionError) --- pipeline/models.py | 4 ++-- pipeline/validation.py | 2 +- tests/test_validation.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index ef697f6..567a21f 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -11,9 +11,9 @@ from .exceptions import InvalidPatternError, ValidationError from .features import LATEST_VERSION, get_feature_flags_for_version from .validation import ( - assert_valid_glob_pattern, validate_cohortextractor_outputs, validate_databuilder_outputs, + validate_glob_pattern, ) @@ -124,7 +124,7 @@ def validate_output_filenames_are_valid( return for output_id, filename in output.items(): try: - assert_valid_glob_pattern(filename, privacy_level) + validate_glob_pattern(filename, privacy_level) except InvalidPatternError as e: raise ValidationError(f"Output path {filename} is invalid: {e}") diff --git a/pipeline/validation.py b/pipeline/validation.py index 3027a46..4fafcde 100644 --- a/pipeline/validation.py +++ b/pipeline/validation.py @@ -13,7 +13,7 @@ from .models import Action -def assert_valid_glob_pattern(pattern: str, privacy_level: str) -> None: +def validate_glob_pattern(pattern: str, privacy_level: str) -> None: """ These patterns get converted into regular expressions and matched with a `find` command so there shouldn't be any possibility of a path diff --git a/tests/test_validation.py b/tests/test_validation.py index 5a6c5bc..83a61f4 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,12 +1,12 @@ import pytest from pipeline.exceptions import InvalidPatternError -from pipeline.validation import assert_valid_glob_pattern +from pipeline.validation import validate_glob_pattern -def test_assert_valid_glob_pattern(): - assert_valid_glob_pattern("foo/bar/*.txt", "highly_sensitive") - assert_valid_glob_pattern("foo/bar/*.txt", "moderately_sensitive") +def test_validate_glob_pattern(): + validate_glob_pattern("foo/bar/*.txt", "highly_sensitive") + validate_glob_pattern("foo/bar/*.txt", "moderately_sensitive") bad_patterns = [ ("/abs/path.txt", "highly_sensitive"), ("not//canonical.txt", "highly_sensitive"), @@ -25,4 +25,4 @@ def test_assert_valid_glob_pattern(): ] for pattern, sensitivity in bad_patterns: with pytest.raises(InvalidPatternError): - assert_valid_glob_pattern(pattern, sensitivity) + validate_glob_pattern(pattern, sensitivity) From 5a4d57af687e84787897c3a5b2d90713ac9de161 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 16/20] Remove single-use type aliases --- pipeline/models.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 567a21f..451987d 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -5,7 +5,7 @@ import shlex from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict +from typing import Any from .constants import RUN_ALL_COMMAND from .exceptions import InvalidPatternError, ValidationError @@ -211,14 +211,10 @@ def is_database_action(self) -> bool: return is_database_action(self.run.parts) -Version = float -Actions = Dict[str, Action] - - @dataclass(frozen=True) class Pipeline: - version: Version - actions: Actions + version: float + actions: dict[str, Action] expectations: Expectations @classmethod From a741832e65af0dbde96622e804a009e3b712dbbb Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 17/20] Remove default value for require Action parameters --- pipeline/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline/models.py b/pipeline/models.py index 451987d..4076481 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -165,8 +165,8 @@ class Action: def build( cls, action_id: str, - outputs: Any = None, - run: Any = None, + outputs: Any, + run: Any, needs: Any = None, config: Any = None, dummy_data_file: Any = None, From aa69d23343f4a4e836f700160a975ed18fb30380 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 18/20] Add runtime type validation This came for free with pydantic. --- pipeline/models.py | 34 ++++++++-- pipeline/validation.py | 14 +++- tests/test_outputs.py | 2 + tests/test_type_validation.py | 116 ++++++++++++++++++++++++++++++++++ 4 files changed, 158 insertions(+), 8 deletions(-) create mode 100644 tests/test_type_validation.py diff --git a/pipeline/models.py b/pipeline/models.py index 4076481..d1daf00 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -14,6 +14,7 @@ validate_cohortextractor_outputs, validate_databuilder_outputs, validate_glob_pattern, + validate_type, ) @@ -75,6 +76,7 @@ class Outputs: @classmethod def build( cls, + action_id: str, highly_sensitive: Any = None, moderately_sensitive: Any = None, minimally_sensitive: Any = None, @@ -90,14 +92,13 @@ def build( ) cls.validate_output_filenames_are_valid( - "highly_sensitive", - highly_sensitive, + action_id, "highly_sensitive", highly_sensitive ) cls.validate_output_filenames_are_valid( - "moderately_sensitive", moderately_sensitive + action_id, "moderately_sensitive", moderately_sensitive ) cls.validate_output_filenames_are_valid( - "minimally_sensitive", minimally_sensitive + action_id, "minimally_sensitive", minimally_sensitive ) return cls(highly_sensitive, moderately_sensitive, minimally_sensitive) @@ -118,11 +119,13 @@ def dict(self) -> dict[str, dict[str, str]]: @classmethod def validate_output_filenames_are_valid( - cls, privacy_level: str, output: Any + cls, action_id: str, privacy_level: str, output: Any ) -> None: if output is None: return + validate_type(output, dict, f"`{privacy_level}` section for action {action_id}") for output_id, filename in output.items(): + validate_type(filename, str, f"`{output_id}` output for action {action_id}") try: validate_glob_pattern(filename, privacy_level) except InvalidPatternError as e: @@ -172,7 +175,22 @@ def build( dummy_data_file: Any = None, **kwargs: Any, ) -> Action: - outputs = Outputs.build(**outputs) + validate_type(outputs, dict, f"`outputs` section for action {action_id}") + validate_type(run, str, f"`run` section for action {action_id}") + validate_type( + needs, list, f"`needs` section for action {action_id}", optional=True + ) + validate_type( + config, dict, f"`config` section for action {action_id}", optional=True + ) + validate_type( + dummy_data_file, + str, + f"`dummy_data_file` section for action {action_id}", + optional=True, + ) + + outputs = Outputs.build(action_id=action_id, **outputs) run = cls.parse_run_string(action_id, run) needs = needs or [] for n in needs: @@ -190,7 +208,7 @@ def build( return action @classmethod - def parse_run_string(cls, action_id: str, run: Any) -> Command: + def parse_run_string(cls, action_id: str, run: str) -> Command: if run == "": raise ValidationError( f"run must have a value, {action_id} has an empty run key" @@ -239,6 +257,7 @@ def build( f"`version` must be a number between 1 and {LATEST_VERSION}" ) + validate_type(actions, dict, "Project `actions` section") actions = { action_id: Action.build(action_id, **action_config) for action_id, action_config in actions.items() @@ -280,6 +299,7 @@ def build( else: expectations = {"population_size": 1000} + validate_type(expectations, dict, "Project `expectations` section") if "population_size" not in expectations: raise ValidationError( "Project `expectations` section must include `population_size` section", diff --git a/pipeline/validation.py b/pipeline/validation.py index 4fafcde..e8bd3de 100644 --- a/pipeline/validation.py +++ b/pipeline/validation.py @@ -2,7 +2,7 @@ import posixpath from pathlib import Path, PurePosixPath, PureWindowsPath -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from .constants import LEVEL4_FILE_TYPES from .exceptions import InvalidPatternError, ValidationError @@ -13,6 +13,18 @@ from .models import Action +def validate_type(val: Any, exp_type: type, loc: str, optional: bool = False) -> None: + type_lookup: dict[type, str] = { + str: "string", + dict: "dictionary of key/value pairs", + list: "list", + } + if optional and val is None: + return + if not isinstance(val, exp_type): + raise ValidationError(f"{loc} must be a {type_lookup[exp_type]}") + + def validate_glob_pattern(pattern: str, privacy_level: str) -> None: """ These patterns get converted into regular expressions and matched diff --git a/tests/test_outputs.py b/tests/test_outputs.py index b727596..00014e0 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -6,6 +6,7 @@ def test_get_output_dirs_with_duplicates(): outputs = Outputs.build( + action_id="test", highly_sensitive={ "a": "output/1a.csv", "b": "output/2a.csv", @@ -19,6 +20,7 @@ def test_get_output_dirs_with_duplicates(): def test_get_output_dirs_without_duplicates(): outputs = Outputs.build( + action_id="test", highly_sensitive={ "a": "1a/output.csv", "b": "2a/output.csv", diff --git a/tests/test_type_validation.py b/tests/test_type_validation.py new file mode 100644 index 0000000..c35e662 --- /dev/null +++ b/tests/test_type_validation.py @@ -0,0 +1,116 @@ +import pytest + +from pipeline.exceptions import ValidationError +from pipeline.models import Pipeline + + +def test_missing_actions(): + with pytest.raises( + ValidationError, match="Project `actions` section must be a dictionary" + ): + Pipeline.build(version=3, expectations={"population_size": 10}) + + +def test_actions_incorrect_type(): + with pytest.raises( + ValidationError, match="Project `actions` section must be a dictionary" + ): + Pipeline.build(version=3, actions=[], expectations={"population_size": 10}) + + +def test_expectations_incorrect_type(): + with pytest.raises( + ValidationError, match="Project `expectations` section must be a dictionary" + ): + Pipeline.build(version=3, actions={}, expectations=[]) + + +def test_outputs_incorrect_type(): + with pytest.raises( + ValidationError, + match="`outputs` section for action action1 must be a dictionary", + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": [], "run": "test:v1"}}, + expectations={"population_size": 10}, + ) + + +def test_run_incorrect_type(): + with pytest.raises( + ValidationError, match="`run` section for action action1 must be a string" + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": {}, "run": ["test:v1"]}}, + expectations={"population_size": 10}, + ) + + +def test_needs_incorrect_type(): + with pytest.raises( + ValidationError, match="`needs` section for action action1 must be a list" + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": {}, "run": "test:v1", "needs": ""}}, + expectations={"population_size": 10}, + ) + + +def test_config_incorrect_type(): + with pytest.raises( + ValidationError, + match="`config` section for action action1 must be a dictionary", + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": {}, "run": "test:v1", "config": []}}, + expectations={"population_size": 10}, + ) + + +def test_dummy_data_file_incorrect_type(): + with pytest.raises( + ValidationError, + match="`dummy_data_file` section for action action1 must be a string", + ): + Pipeline.build( + version=3, + actions={ + "action1": {"outputs": {}, "run": "test:v1", "dummy_data_file": []} + }, + expectations={"population_size": 10}, + ) + + +def test_output_files_incorrect_type(): + with pytest.raises( + ValidationError, + match="`highly_sensitive` section for action action1 must be a dictionary", + ): + Pipeline.build( + version=3, + actions={ + "action1": {"outputs": {"highly_sensitive": []}, "run": "test:v1"} + }, + expectations={"population_size": 10}, + ) + + +def test_output_filename_incorrect_type(): + with pytest.raises( + ValidationError, + match="`dataset` output for action action1 must be a string", + ): + Pipeline.build( + version=3, + actions={ + "action1": { + "outputs": {"highly_sensitive": {"dataset": {}}}, + "run": "test:v1", + } + }, + expectations={"population_size": 10}, + ) From b91f35ede9c2ba0ddc6470050cae8ae00e4e66ad Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 19/20] Add runtime validation that no extra parameters are provided I'd assumed that pydantic did this, but it turns out it doesn't. Feels like it's worth checking for, as extra parameters are likely to be a mistake. --- pipeline/models.py | 6 ++++ pipeline/validation.py | 5 +++ tests/test_type_validation.py | 60 +++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/pipeline/models.py b/pipeline/models.py index d1daf00..65fc257 100644 --- a/pipeline/models.py +++ b/pipeline/models.py @@ -14,6 +14,7 @@ validate_cohortextractor_outputs, validate_databuilder_outputs, validate_glob_pattern, + validate_no_kwargs, validate_type, ) @@ -58,6 +59,7 @@ def build( population_size: Any = None, **kwargs: Any, ) -> Expectations: + validate_no_kwargs(kwargs, "project `expectations` section") try: population_size = int(population_size) except (TypeError, ValueError): @@ -91,6 +93,8 @@ def build( f"must specify at least one output of: {', '.join(['highly_sensitive', 'moderately_sensitive', 'minimally_sensitive'])}" ) + validate_no_kwargs(kwargs, f"`outputs` section for action {action_id}") + cls.validate_output_filenames_are_valid( action_id, "highly_sensitive", highly_sensitive ) @@ -175,6 +179,7 @@ def build( dummy_data_file: Any = None, **kwargs: Any, ) -> Action: + validate_no_kwargs(kwargs, f"action {action_id}") validate_type(outputs, dict, f"`outputs` section for action {action_id}") validate_type(run, str, f"`run` section for action {action_id}") validate_type( @@ -243,6 +248,7 @@ def build( expectations: Any = None, **kwargs: Any, ) -> Pipeline: + validate_no_kwargs(kwargs, "project") if version is None: raise ValidationError( f"Project file must have a `version` attribute specifying which " diff --git a/pipeline/validation.py b/pipeline/validation.py index e8bd3de..d1cb2f2 100644 --- a/pipeline/validation.py +++ b/pipeline/validation.py @@ -25,6 +25,11 @@ def validate_type(val: Any, exp_type: type, loc: str, optional: bool = False) -> raise ValidationError(f"{loc} must be a {type_lookup[exp_type]}") +def validate_no_kwargs(kwargs: dict[str, Any], loc: str) -> None: + if kwargs: + raise ValidationError(f"Unexpected parameters ({', '.join(kwargs)}) in {loc}") + + def validate_glob_pattern(pattern: str, privacy_level: str) -> None: """ These patterns get converted into regular expressions and matched diff --git a/tests/test_type_validation.py b/tests/test_type_validation.py index c35e662..2ae5720 100644 --- a/tests/test_type_validation.py +++ b/tests/test_type_validation.py @@ -1,3 +1,5 @@ +import re + import pytest from pipeline.exceptions import ValidationError @@ -114,3 +116,61 @@ def test_output_filename_incorrect_type(): }, expectations={"population_size": 10}, ) + + +def test_project_extra_parameters(): + with pytest.raises( + ValidationError, match=re.escape("Unexpected parameters (extra) in project") + ): + Pipeline.build(extra=123) + + +def test_action_extra_parameters(): + with pytest.raises( + ValidationError, + match=re.escape("Unexpected parameters (extra) in action action1"), + ): + Pipeline.build( + version=3, + actions={ + "action1": { + "outputs": {}, + "run": "test:v1", + "extra": 123, + } + }, + expectations={"population_size": 10}, + ) + + +def test_outputs_extra_parameters(): + with pytest.raises( + ValidationError, + match=re.escape( + "Unexpected parameters (extra) in `outputs` section for action action1" + ), + ): + Pipeline.build( + version=3, + actions={ + "action1": { + "outputs": {"highly_sensitive": {"dataset": {}}, "extra": 123}, + "run": "test:v1", + } + }, + expectations={"population_size": 10}, + ) + + +def test_expectations_extra_parameters(): + with pytest.raises( + ValidationError, + match=re.escape( + "Unexpected parameters (extra) in project `expectations` section" + ), + ): + Pipeline.build( + version=3, + actions={}, + expectations={"population_size": 10, "extra": 123}, + ) From ecf3e318a5bd340cc781da520f68e1c7ded63676 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 27 Sep 2024 10:22:01 +0100 Subject: [PATCH 20/20] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d5f1762..89731d2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ For example: data = load_pipeline(f.read()) -The returned object is a Pydantic model, `Pipeline`, defined in `pipeline/models.py`. +The returned object is an instance of `pipeline.models.Pipeline`. ## Developer docs