diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml index d336969..691b47c 100644 --- a/.github/workflows/code-quality-main.yaml +++ b/.github/workflows/code-quality-main.yaml @@ -13,10 +13,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install packages + run: | + pip install .[dev] - name: Run pre-commits - uses: pre-commit/action@v3.0.0 + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml index 7ca7753..bee2e11 100644 --- a/.github/workflows/code-quality-pr.yaml +++ b/.github/workflows/code-quality-pr.yaml @@ -16,10 +16,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install packages + run: | + pip install .[dev] - name: Find modified files id: file_changes @@ -31,6 +37,6 @@ jobs: run: echo '${{ steps.file_changes.outputs.files}}' - name: Run pre-commits - uses: pre-commit/action@v3.0.0 + uses: pre-commit/action@v3.0.1 with: extra_args: --files ${{ steps.file_changes.outputs.files}} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c22ef16..4e51b11 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,16 +17,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install packages run: | - pip install -e .[dev] + pip install .[dev] #---------------------------------------------- # run test suite diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 591bc53..8210517 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ exclude: "to_organize" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace diff --git a/README.md b/README.md index eb335ea..a375aef 100644 --- a/README.md +++ b/README.md @@ -217,15 +217,18 @@ normal_spo2: Fields for a "plain" predicate: - `code` (required): Must be one of the following: - - a string with `//` sequence separating the column name and column value. - - a list of strings as above in the form of {any: \[???, ???, ...\]}, which will match any of the listed codes. - - a regex in the form of {regex: "???"}, which will match any code that matches that regular expression. + - a string matching values in a column named `code` (for `MEDS` only). + - a string with a `//` sequence separating the column name and the matching column value (for `ESGPT` only). + - a list of strings as above in the form of `{any: [???, ???, ...]}` (or the corresponding expanded indented `YAML` format), which will match any of the listed codes. + - a regex in the form of `{regex: "???"}` (or the corresponding expanded indented `YAML` format), which will match any code that matches that regular expression. + - `value_min` (optional): Must be float or integer specifying the minimum value of the predicate, if the variable is presented as numerical values. - `value_max` (optional): Must be float or integer specifying the maximum value of the predicate, if the variable is presented as numerical values. - `value_min_inclusive` (optional): Must be a boolean specifying whether `value_min` is inclusive or not. - `value_max_inclusive` (optional): Must be a boolean specifying whether `value_max` is inclusive or not. 
- `other_cols` (optional): Must be a 1-to-1 dictionary of column name and column value, which places additional constraints on further columns. +**Note**: For memory optimization, we strongly recommend using either the List of Values or Regular Expression formats whenever possible, especially when needing to match multiple values. Defining each code as an individual string will increase memory usage significantly, as each code generates a separate predicate column. Using a list or regex consolidates multiple matching codes under a single column, reducing the overall memory footprint. + #### Derived Predicates "Derived" predicates combine existing "plain" predicates using `and` / `or` keywords and have exactly 1 required `expr` field: For instance, the following defines a predicate representing either death or discharge (by combining "plain" predicates of `death` and `discharge`): diff --git a/docs/source/configuration.md b/docs/source/configuration.md index ca83871..6571ae9 100644 --- a/docs/source/configuration.md +++ b/docs/source/configuration.md @@ -49,20 +49,29 @@ These configs consist of the following four fields: The field can additionally be a dictionary with either a `regex` key and the value being a regular expression (satisfied if the regular expression evaluates to True), or a `any` key and the value being a list of strings (satisfied if there is an occurrence for any code in the list). + + **Note**: Each individual definition of `PlainPredicateConfig` and `code` will generate a separate predicate + column. Thus, for memory optimization, it is strongly recommended to match multiple values using either the + List of Values or Regular Expression formats whenever possible. 
+ - `value_min`: If specified, an observation will only satisfy this predicate if the occurrence of the underlying `code` has a reported numerical value that is either greater than, or greater than or equal to, `value_min` (with these options being decided on the basis of `value_min_inclusive`, where `value_min_inclusive=True` indicates that an observation satisfies this predicate if its value is greater than or equal to `value_min`, and `value_min_inclusive=False` indicates that a greater than but not equal to comparison will be used). + - `value_max`: If specified, an observation will only satisfy this predicate if the occurrence of the underlying `code` has a reported numerical value that is either less than, or less than or equal to, `value_max` (with these options being decided on the basis of `value_max_inclusive`, where `value_max_inclusive=True` indicates that an observation satisfies this predicate if its value is less than or equal to `value_max`, and `value_max_inclusive=False` indicates that a less than but not equal to comparison will be used). + - `value_min_inclusive`: See `value_min` + - `value_max_inclusive`: See `value_max` + - `other_cols`: This optional field accepts a 1-to-1 dictionary of column names to column values, and can be used to specify further constraints on other columns (i.e., not `code`) for this predicate. 
diff --git a/pyproject.toml b/pyproject.toml index 0a3399d..4b2b26a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ expand_shards = "aces.expand_shards:main" [project.optional-dependencies] dev = [ - "pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis" + "pre-commit<4", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis" ] profiling = ["psutil"] diff --git a/src/aces/config.py b/src/aces/config.py index b2a3183..415954a 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -1074,7 +1074,7 @@ class TaskExtractorConfig: >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={}) Traceback (most recent call last): ... - KeyError: "Missing 1 relationships:\\nDerived predicate 'foobar' references undefined predicate 'bar'" + KeyError: "Missing 1 relationships: Derived predicate 'foobar' references undefined predicate 'bar'" >>> predicates = {"foo": PlainPredicateConfig("foo")} >>> trigger = EventConfig("foo") @@ -1166,6 +1166,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta start_inclusive=True, end_inclusive=True, has={}, label=None, index_timestamp=None)}, label_window=None, index_timestamp_window=None) + >>> predicates_dict = { ... "metadata": {'description': 'A test predicates file'}, ... "description": 'this is a test', @@ -1195,6 +1196,83 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta start_inclusive=True, end_inclusive=True, has={}, label=None, index_timestamp=None)}, label_window=None, index_timestamp_window=None) + + >>> config_dict = { + ... "metadata": {'description': 'A test configuration file'}, + ... "description": 'this is a test for joining static and plain predicates', + ... "patient_demographics": {"male": {"code": "MALE"}, "female": {"code": "FEMALE"}}, + ... "predicates": {"normal_male_lab_range": {"code": "LAB", "value_min": 0, "value_max": 100, + ... 
"value_min_inclusive": True, "value_max_inclusive": True}, + ... "normal_female_lab_range": {"code": "LAB", "value_min": 0, "value_max": 90, + ... "value_min_inclusive": True, "value_max_inclusive": True}, + ... "normal_lab_male": {"expr": "and(normal_male_lab_range, male)"}, + ... "normal_lab_female": {"expr": "and(normal_female_lab_range, female)"}}, + ... "trigger": "_ANY_EVENT", + ... "windows": { + ... "start": { + ... "start": None, "end": "trigger + 24h", "start_inclusive": True, + ... "end_inclusive": True, "has": {"normal_lab_male": "(1, None)"}, + ... } + ... }, + ... } + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... config_path = Path(f.name) + ... yaml.dump(config_dict, f) + ... cfg = TaskExtractorConfig.load(config_path) + >>> cfg.predicates.keys() # doctest: +NORMALIZE_WHITESPACE + dict_keys(['normal_lab_male', 'normal_male_lab_range', 'female', 'male']) + + >>> config_dict = { + ... "metadata": {'description': 'A test configuration file'}, + ... "description": 'this is a test for nested derived predicates', + ... "patient_demographics": {"male": {"code": "MALE"}, "female": {"code": "FEMALE"}}, + ... "predicates": {"abnormally_low_male_lab_range": {"code": "LAB", "value_max": 90, + ... "value_max_inclusive": False}, + ... "abnormally_low_female_lab_range": {"code": "LAB", "value_max": 80, + ... "value_max_inclusive": False}, + ... "abnormally_high_lab_range": {"code": "LAB", "value_min": 120, + ... "value_min_inclusive": False}, + ... "abnormal_lab_male_range": {"expr": + ... "or(abnormally_low_male_lab_range, abnormally_high_lab_range)"}, + ... "abnormal_lab_female_range": {"expr": + ... "or(abnormally_low_female_lab_range, abnormally_high_lab_range)"}, + ... "abnormal_lab_male": {"expr": "and(abnormal_lab_male_range, male)"}, + ... "abnormal_lab_female": {"expr": "and(abnormal_lab_female_range, female)"}, + ... "abnormal_labs": {"expr": "or(abnormal_lab_male, abnormal_lab_female)"}}, + ... "trigger": "_ANY_EVENT", + ... 
"windows": { + ... "start": { + ... "start": None, "end": "trigger + 24h", "start_inclusive": True, + ... "end_inclusive": True, "has": {"abnormal_labs": "(1, None)"}, + ... } + ... }, + ... } + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... config_path = Path(f.name) + ... yaml.dump(config_dict, f) + ... cfg = TaskExtractorConfig.load(config_path) + >>> cfg.predicates.keys() # doctest: +NORMALIZE_WHITESPACE + dict_keys(['abnormal_lab_female', 'abnormal_lab_female_range', 'abnormal_lab_male', + 'abnormal_lab_male_range', 'abnormal_labs', 'abnormally_high_lab_range', + 'abnormally_low_female_lab_range', 'abnormally_low_male_lab_range', 'female', 'male']) + + >>> predicates_dict = { + ... "metadata": {'description': 'A test predicates file'}, + ... "description": 'this is a test', + ... "patient_demographics": {"brown_eyes": {"code": "eye_color//BR"}}, + ... "predicates": {'admission': "invalid"}, + ... } + >>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp, + ... tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp): + ... config_path = Path(config_fp.name) + ... pred_path = Path(pred_fp.name) + ... yaml.dump(no_predicates_config, config_fp) + ... yaml.dump(predicates_dict, pred_fp) + ... cfg = TaskExtractorConfig.load(config_path, pred_path) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Predicate 'admission' is not defined correctly in the configuration file. Currently + defined as the string: invalid. Please refer to the documentation for the supported formats. 
""" if isinstance(config_path, str): config_path = Path(config_path) @@ -1258,6 +1336,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta final_predicates = {**predicates, **overriding_predicates} final_demographics = {**patient_demographics, **overriding_demographics} + all_predicates = {**final_predicates, **final_demographics} logger.info("Parsing windows...") if windows is None: @@ -1271,23 +1350,45 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta logger.info("Parsing trigger event...") trigger = EventConfig(trigger) + # add window referenced predicates referenced_predicates = {pred for w in windows.values() for pred in w.referenced_predicates} + + # add trigger predicate referenced_predicates.add(trigger.predicate) + + # add label predicate if it exists and not already added label_reference = [w.label for w in windows.values() if w.label] if label_reference: referenced_predicates.update(set(label_reference)) - current_predicates = set(referenced_predicates) + special_predicates = {ANY_EVENT_COLUMN, START_OF_RECORD_KEY, END_OF_RECORD_KEY} - for pred in current_predicates - special_predicates: - if pred not in final_predicates: + for pred in set(referenced_predicates) - special_predicates: + if pred not in all_predicates: raise KeyError( - f"Something referenced predicate {pred} that wasn't defined in the configuration." - ) - if "expr" in final_predicates[pred]: - referenced_predicates.update( - DerivedPredicateConfig(**final_predicates[pred]).input_predicates + f"Something referenced predicate '{pred}' that wasn't defined in the configuration." ) + if "expr" in all_predicates[pred]: + stack = list(DerivedPredicateConfig(**all_predicates[pred]).input_predicates) + + while stack: + nested_pred = stack.pop() + + if nested_pred not in all_predicates: + raise KeyError( + f"Predicate '{nested_pred}' referenced in '{pred}' is not defined in the " + "configuration." 
+ ) + + # if nested_pred is a DerivedPredicateConfig, unpack input_predicates and add to stack + if "expr" in all_predicates[nested_pred]: + derived_config = DerivedPredicateConfig(**all_predicates[nested_pred]) + stack.extend(derived_config.input_predicates) + referenced_predicates.add(nested_pred) # also add itself to referenced_predicates + else: + # if nested_pred is a PlainPredicateConfig, only add it to referenced_predicates + referenced_predicates.add(nested_pred) + logger.info("Parsing predicates...") predicates_to_parse = {k: v for k, v in final_predicates.items() if k in referenced_predicates} predicate_objs = {} @@ -1295,6 +1396,12 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta if "expr" in p: predicate_objs[n] = DerivedPredicateConfig(**p) else: + if isinstance(p, str): + raise ValueError( + f"Predicate '{n}' is not defined correctly in the configuration file. " + f"Currently defined as the string: {p}. " + "Please refer to the documentation for the supported formats." 
+ ) config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__} other_cols = {k: v for k, v in p.items() if k not in config_data} predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols) @@ -1344,7 +1451,7 @@ def _initialize_predicates(self): ) if missing_predicates: raise KeyError( - f"Missing {len(missing_predicates)} relationships:\n" + "\n".join(missing_predicates) + f"Missing {len(missing_predicates)} relationships: " + "; ".join(missing_predicates) ) self._predicate_dag_graph = nx.DiGraph(dag_relationships) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 2f10a76..e4a1ca0 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -609,6 +609,64 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │ │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ └────────────┴─────────────────────┴─────┴──────┴────────────┴───────────────┴─────────────┘ + + >>> data = pl.DataFrame({ + ... "subject_id": [1, 1, 1, 2, 2], + ... "timestamp": [ + ... None, + ... "01/01/2021 00:00", + ... "01/01/2021 12:00", + ... "01/02/2021 00:00", + ... "01/02/2021 12:00"], + ... "adm": [0, 1, 0, 1, 0], + ... "male": [1, 0, 0, 0, 0], + ... }) + >>> predicates = { + ... "adm": PlainPredicateConfig("adm"), + ... "male": PlainPredicateConfig("male", static=True), # predicate match based on name for direct + ... "male_adm": DerivedPredicateConfig("and(male, adm)", static=['male']), + ... } + >>> trigger = EventConfig("adm") + >>> windows = { + ... "input": WindowConfig( + ... start=None, + ... end="trigger + 24h", + ... start_inclusive=True, + ... end_inclusive=True, + ... has={"_ANY_EVENT": "(32, None)"}, + ... ), + ... "gap": WindowConfig( + ... start="input.end", + ... end="start + 24h", + ... start_inclusive=False, + ... end_inclusive=True, + ... has={ + ... "adm": "(None, 0)", + ... "male_adm": "(None, 0)", + ... }, + ... ), + ... 
} + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows) + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: + ... data_path = Path(f.name) + ... data.write_csv(data_path) + ... data_config = DictConfig({ + ... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M" + ... }) + ... get_predicates_df(config, data_config) + shape: (5, 6) + ┌────────────┬─────────────────────┬─────┬──────┬──────────┬────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ male ┆ male_adm ┆ _ANY_EVENT │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪══════╪══════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 1 ┆ 0 ┆ null │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 1 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + └────────────┴─────────────────────┴─────┴──────┴──────────┴────────────┘ + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: ... data_path = Path(f.name) ... data.write_csv(data_path) @@ -638,15 +696,26 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D raise ValueError(f"Invalid data standard: {standard}. Options are 'direct', 'MEDS', 'ESGPT'.") predicate_cols = list(plain_predicates.keys()) + data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) + # derived predicates logger.info("Loaded plain predicates. 
Generating derived predicate columns...") + static_variables = [pred for pred in cfg.plain_predicates if cfg.plain_predicates[pred].static] for name, code in cfg.derived_predicates.items(): + if any(x in static_variables for x in code.input_predicates): + data = data.with_columns( + [ + pl.col(static_var) + .first() + .over("subject_id") # take the first value in each subject_id group and propagate it + .alias(static_var) + for static_var in static_variables + ] + ) data = data.with_columns(code.eval_expr().cast(PRED_CNT_TYPE).alias(name)) logger.info(f"Added predicate column '{name}'.") predicate_cols.append(name) - data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) - # add special predicates: # a column of 1s representing any predicate # a column of 0s with 1 in the first event of each subject_id representing the start of record diff --git a/src/aces/query.py b/src/aces/query.py index 701bf8f..07cef47 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -106,8 +106,9 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame return pl.DataFrame() result = extract_subtree(cfg.window_tree, prospective_root_anchors, predicates_df) - if result.is_empty(): - logger.info("No valid rows found.") + if result.is_empty(): # pragma: no cover + logger.warning("No valid rows found.") + return pl.DataFrame() else: # number of patients logger.info( @@ -125,7 +126,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame # add label column if specified if cfg.label_window: - logger.info( + logger.info( # pragma: no cover f"Extracting label '{cfg.windows[cfg.label_window].label}' from window " f"'{cfg.label_window}'..." ) @@ -137,9 +138,16 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame ) to_return_cols.insert(1, "label") + if result["label"].n_unique() == 1: # pragma: no cover + logger.warning( + f"All labels in the extracted cohort are the same: '{result['label'][0]}'. 
" + "This may indicate an issue with the task logic. " + "Please double-check your configuration file if this is not expected." + ) + # add index_timestamp column if specified if cfg.index_timestamp_window: - logger.info( + logger.info( # pragma: no cover f"Setting index timestamp as '{cfg.windows[cfg.index_timestamp_window].index_timestamp}' " f"of window '{cfg.index_timestamp_window}'..." )