Merge pull request #115 from justin13601/62_improve_testing
Improvements to testing
mmcdermott authored Aug 24, 2024
2 parents 208a91d + af93dab commit 35538ad
Showing 7 changed files with 627 additions and 345 deletions.
20 changes: 13 additions & 7 deletions src/aces/extract_subtree.py
@@ -54,6 +54,7 @@ def extract_subtree(
Examples:
>>> from bigtree import Node
>>> from datetime import datetime
+ >>> from .types import ToEventWindowBounds, TemporalWindowBounds
>>> # We'll use an example for in-hospital mortality prediction. Our root event of the tree will be
>>> # an admission event.
>>> root = Node("admission")
@@ -63,27 +64,27 @@ def extract_subtree(
>>> # Node 1 will represent our gap window. We say that in the 24 hours after the admission, there
>>> # should be no discharges, deaths, or covid events.
>>> gap_node = Node("gap") # This sets the node's name.
- >>> gap_node.endpoint_expr = (True, timedelta(days=2), True)
+ >>> gap_node.endpoint_expr = TemporalWindowBounds(True, timedelta(days=2), True)
>>> gap_node.constraints = {
... "is_discharge": (None, 0), "is_death": (None, 0), "is_covid_dx": (None, 0)
... }
>>> gap_node.parent = root
>>> # Node 2 will start our target window and span until the next discharge or death event.
>>> # There should be no covid events.
>>> target_node = Node("target") # This sets the node's name.
- >>> target_node.endpoint_expr = (True, "is_discharge", True)
+ >>> target_node.endpoint_expr = ToEventWindowBounds(True, "is_discharge", True)
>>> target_node.constraints = {"is_covid_dx": (None, 0)}
>>> target_node.parent = gap_node
>>> #
>>> #### BRANCH 2 ####
>>> # Finally, for our second branch, we will impose no constraints but track the input time range,
>>> # which will span from the beginning of the record to 24 hours after admission.
>>> input_end_node = Node("input_end")
- >>> input_end_node.endpoint_expr = (True, timedelta(days=1), True)
+ >>> input_end_node.endpoint_expr = TemporalWindowBounds(True, timedelta(days=1), True)
>>> input_end_node.constraints = {}
>>> input_end_node.parent = root
>>> input_start_node = Node("input_start")
- >>> input_start_node.endpoint_expr = (True, "-_RECORD_START", True)
+ >>> input_start_node.endpoint_expr = ToEventWindowBounds(True, "-_RECORD_START", True)
>>> input_start_node.constraints = {}
>>> input_start_node.parent = root
>>> #
@@ -93,11 +94,11 @@ def extract_subtree(
>>> # This will be expressed through two windows, one spanning back a year, and the other looking
>>> # prior to that year.
>>> pre_node_1yr = Node("pre_node_1yr")
- >>> pre_node_1yr.endpoint_expr = (False, timedelta(days=-365), False)
+ >>> pre_node_1yr.endpoint_expr = TemporalWindowBounds(False, timedelta(days=-365), False)
>>> pre_node_1yr.constraints = {}
>>> pre_node_1yr.parent = root
>>> pre_node_total = Node("pre_node_total")
- >>> pre_node_total.endpoint_expr = (False, "-_RECORD_START", False)
+ >>> pre_node_total.endpoint_expr = ToEventWindowBounds(False, "-_RECORD_START", False)
>>> pre_node_total.constraints = {"*": (1, None)}
>>> pre_node_total.parent = pre_node_1yr
>>> #
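Note: the tuples previously assigned to `endpoint_expr` are now wrapped in named bound types from `src/aces/types.py`. That file is not rendered in this commit view, so the following is only a rough sketch of the two types' shape, with field names inferred from the positional arguments above and from the `endpoint_expr.end_event` access later in this diff:

```python
from dataclasses import dataclass
from datetime import timedelta


@dataclass
class TemporalWindowBounds:
    """A window bounded by a fixed offset from its anchor event (sketch, not the actual definition)."""

    left_inclusive: bool
    window_size: timedelta  # negative values extend the window backwards in time
    right_inclusive: bool


@dataclass
class ToEventWindowBounds:
    """A window bounded by the next matching event (sketch, not the actual definition)."""

    left_inclusive: bool
    end_event: str  # e.g. "is_discharge"; a "-" prefix such as "-_RECORD_START" looks backwards
    right_inclusive: bool
```

Whatever the exact definitions, the named types document what each position of the old opaque 3-tuple meant at the call site.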
@@ -272,11 +273,16 @@ def extract_subtree(
# In an event bound case, the child root will be a proper extant event, so it will be the
# anchor as well, and thus the child root offset should be zero.
child_root_offset = timedelta(days=0)
+ if endpoint_expr.end_event.startswith("-"):
+     child_anchor_time = "timestamp_at_start"
+ else:
+     child_anchor_time = "timestamp_at_end"
+
window_summary_df = (
aggregate_event_bound_window(predicates_df, endpoint_expr)
.with_columns(
pl.col("timestamp").alias("subtree_anchor_timestamp"),
-     pl.col("timestamp_at_end").alias("child_anchor_timestamp"),
+     pl.col(child_anchor_time).alias("child_anchor_timestamp"),
)
.drop("timestamp")
)
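The new branch above picks the child anchor by direction: the subtree anchor stays at the row's own `timestamp`, but when `end_event` carries a `-` prefix (e.g. `-_RECORD_START`) the window extends backwards, so children must anchor at `timestamp_at_start` rather than `timestamp_at_end`. A minimal self-contained sketch of that dispatch, using a hypothetical one-row frame in place of real `aggregate_event_bound_window` output:

```python
from datetime import datetime

import polars as pl

# Hypothetical window summary row: the anchor event plus the window's two endpoints.
window_summary_df = pl.DataFrame(
    {
        "timestamp": [datetime(1991, 1, 27, 23, 32)],
        "timestamp_at_start": [datetime(1989, 12, 1, 12, 3)],
        "timestamp_at_end": [datetime(1991, 1, 31, 2, 15)],
    }
)

for end_event in ("is_discharge", "-_RECORD_START"):
    # Backwards windows ("-" prefix) terminate at their *start* timestamp.
    child_anchor_time = "timestamp_at_start" if end_event.startswith("-") else "timestamp_at_end"
    out = window_summary_df.with_columns(
        pl.col("timestamp").alias("subtree_anchor_timestamp"),
        pl.col(child_anchor_time).alias("child_anchor_timestamp"),
    ).drop("timestamp")
    print(end_event, "->", out["child_anchor_timestamp"].to_list())
```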
14 changes: 11 additions & 3 deletions src/aces/utils.py
@@ -41,7 +41,16 @@ def parse_timedelta(time_str: str) -> timedelta:

@contextmanager
def capture_output():
-     """A context manager to capture stdout output."""
+     """A context manager to capture stdout output.
+
+     This can eventually be eliminated if https://github.com/kayjan/bigtree/issues/285 is resolved.
+
+     Examples:
+         >>> with capture_output() as captured:
+         ...     print("Hello, world!")
+         >>> captured.getvalue().strip()
+         'Hello, world!'
+     """
new_out = io.StringIO() # Create a StringIO object to capture output
old_out = sys.stdout # Save the current stdout so we can restore it later
try:
@@ -54,9 +54,8 @@ def capture_output():
def log_tree(node):
"""Logs the tree structure using logging.info."""
with capture_output() as captured:
-         print("\n")
print_tree(node, style="const_bold") # This will print to the captured StringIO instead of stdout
-     logger.info(captured.getvalue())  # Log the captured output
+     logger.info("\n" + captured.getvalue())  # Log the captured output


def hydra_loguru_init(filename) -> None:
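With these two changes, `log_tree` emits the rendered tree as a single loguru record with one leading newline instead of printing a stray blank line first. A short usage sketch, assuming this module is importable as `aces.utils`:

```python
from bigtree import Node

from aces.utils import log_tree  # module path assumed from src/aces/utils.py

# Build a tiny window tree and log it; print_tree's output is captured and
# routed through loguru rather than written to stdout.
root = Node("trigger")
Node("input.end", parent=root)
gap = Node("gap.end", parent=root)
Node("target.end", parent=gap)

log_tree(root)
```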
55 changes: 0 additions & 55 deletions tests/test_check_static_variables.py

This file was deleted.

208 changes: 64 additions & 144 deletions tests/test_e2e.py
@@ -4,13 +4,11 @@

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

- import tempfile
- from pathlib import Path
+ from datetime import datetime

import polars as pl
- from loguru import logger

- from .utils import assert_df_equal, run_command
+ from .utils import cli_test

pl.enable_string_cache()

@@ -129,147 +127,69 @@

# Expected output
EXPECTED_OUTPUT = {
-     "inhospital_mortality": {
-         "subject_id": [1],
-         "index_timestamp": ["01/28/1991 23:32"],
-         "label": [0],
-         "trigger": ["01/27/1991 23:32"],
-         "input.end_summary": [
-             {
-                 "window_name": "input.end",
-                 "timestamp_at_start": "01/27/1991 23:32",
-                 "timestamp_at_end": "01/28/1991 23:32",
-                 "admission": 0,
-                 "discharge": 0,
-                 "death": 0,
-                 "discharge_or_death": 0,
-                 "_ANY_EVENT": 4,
-             },
-         ],
-         "input.start_summary": [
-             {
-                 "window_name": "input.start",
-                 "timestamp_at_start": "12/01/1989 12:03",
-                 "timestamp_at_end": "01/28/1991 23:32",
-                 "admission": 2,
-                 "discharge": 1,
-                 "death": 0,
-                 "discharge_or_death": 1,
-                 "_ANY_EVENT": 16,
-             },
-         ],
-         "gap.end_summary": [
-             {
-                 "window_name": "gap.end",
-                 "timestamp_at_start": "01/27/1991 23:32",
-                 "timestamp_at_end": "01/29/1991 23:32",
-                 "admission": 0,
-                 "discharge": 0,
-                 "death": 0,
-                 "discharge_or_death": 0,
-                 "_ANY_EVENT": 5,
-             },
-         ],
-         "target.end_summary": [
-             {
-                 "window_name": "target.end",
-                 "timestamp_at_start": "01/29/1991 23:32",
-                 "timestamp_at_end": "01/31/1991 02:15",
-                 "admission": 0,
-                 "discharge": 1,
-                 "death": 0,
-                 "discharge_or_death": 1,
-                 "_ANY_EVENT": 7,
-             },
-         ],
-     }
+     "inhospital_mortality": pl.DataFrame(
+         {
+             "subject_id": [1],
+             "index_timestamp": [datetime(1991, 1, 28, 23, 32)],
+             "label": [0],
+             "trigger": [datetime(1991, 1, 27, 23, 32)],
+             "input.end_summary": [
+                 {
+                     "window_name": "input.end",
+                     "timestamp_at_start": datetime(1991, 1, 27, 23, 32),
+                     "timestamp_at_end": datetime(1991, 1, 28, 23, 32),
+                     "admission": 0,
+                     "discharge": 0,
+                     "death": 0,
+                     "discharge_or_death": 0,
+                     "_ANY_EVENT": 4,
+                 },
+             ],
+             "input.start_summary": [
+                 {
+                     "window_name": "input.start",
+                     "timestamp_at_start": datetime(1989, 12, 1, 12, 3),
+                     "timestamp_at_end": datetime(1991, 1, 28, 23, 32),
+                     "admission": 2,
+                     "discharge": 1,
+                     "death": 0,
+                     "discharge_or_death": 1,
+                     "_ANY_EVENT": 16,
+                 },
+             ],
+             "gap.end_summary": [
+                 {
+                     "window_name": "gap.end",
+                     "timestamp_at_start": datetime(1991, 1, 27, 23, 32),
+                     "timestamp_at_end": datetime(1991, 1, 29, 23, 32),
+                     "admission": 0,
+                     "discharge": 0,
+                     "death": 0,
+                     "discharge_or_death": 0,
+                     "_ANY_EVENT": 5,
+                 },
+             ],
+             "target.end_summary": [
+                 {
+                     "window_name": "target.end",
+                     "timestamp_at_start": datetime(1991, 1, 29, 23, 32),
+                     "timestamp_at_end": datetime(1991, 1, 31, 2, 15),
+                     "admission": 0,
+                     "discharge": 1,
+                     "death": 0,
+                     "discharge_or_death": 1,
+                     "_ANY_EVENT": 7,
+                 },
+             ],
+         }
+     )
}
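The switch to `pl.DataFrame` literals works because polars infers a `Struct` dtype with genuine `Datetime` fields directly from a list of dicts, so the old `str.strptime` post-processing of formatted timestamp strings becomes unnecessary. A minimal illustration, mirroring one of the summary entries above:

```python
from datetime import datetime

import polars as pl

# A list of dicts becomes a single Struct column; datetime values become Datetime fields.
df = pl.DataFrame(
    {
        "gap.end_summary": [
            {
                "window_name": "gap.end",
                "timestamp_at_start": datetime(1991, 1, 27, 23, 32),
                "timestamp_at_end": datetime(1991, 1, 29, 23, 32),
                "_ANY_EVENT": 5,
            }
        ],
    }
)
print(df.schema)  # gap.end_summary resolves to a Struct with Datetime fields
```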


def test_e2e():
-     with tempfile.TemporaryDirectory() as d:
-         data_dir = Path(d) / "sample_data"
-         configs_dir = Path(d) / "sample_configs"
-         output_dir = Path(d) / "sample_output"
-
-         # Create the directories
-         data_dir.mkdir()
-         configs_dir.mkdir()
-         output_dir.mkdir()
-
-         # Write the predicates CSV file
-         predicates_csv = data_dir / "sample_data.csv"
-         predicates_csv.write_text(PREDICATES_CSV.strip())
-
-         # Run script and check the outputs
-         all_stderrs = []
-         all_stdouts = []
-         full_stderr = ""
-         full_stdout = ""
-         try:
-             for task_name, task_cfg in TASKS_CFGS.items():
-                 logger.info(f"Running task '{task_name}'...")
-
-                 # Write the task config YAMLs
-                 task_cfg_path = configs_dir / f"{task_name}.yaml"
-                 task_cfg_path.write_text(task_cfg)
-
-                 output_path = output_dir / f"{task_name}.parquet"
-
-                 extraction_config_kwargs = {
-                     "data.path": str(predicates_csv.resolve()),
-                     "data.standard": "direct",
-                     "cohort_dir": str(configs_dir.resolve()),
-                     "cohort_name": task_name,
-                     "output_filepath": str(output_path.resolve()),
-                     "hydra.verbose": True,
-                 }
-
-                 stderr, stdout = run_command("aces-cli", extraction_config_kwargs, task_name)
-                 stderr, stdout = run_command("aces-cli", extraction_config_kwargs, task_name)
-
-                 all_stderrs.append(stderr)
-                 all_stdouts.append(stdout)
-
-                 full_stderr = "\n".join(all_stderrs)
-                 full_stdout = "\n".join(all_stdouts)
-
-                 fp = output_dir / f"{task_name}.parquet"
-                 assert fp.is_file(), f"Expected {fp} to exist."
-                 got_df = pl.read_parquet(fp)
-
-                 # Check the columns
-                 expected_columns = EXPECTED_OUTPUT[task_name].keys()
-                 assert got_df.columns == list(expected_columns), f"Columns mismatch for task '{task_name}'"
-
-                 # Check the data
-                 for col_name, expected_data in EXPECTED_OUTPUT[task_name].items():
-                     if col_name in ["index_timestamp", "trigger"]:
-                         want = pl.DataFrame({col_name: expected_data}).with_columns(
-                             pl.col(col_name).str.strptime(pl.Datetime, format=TS_FORMAT)
-                         )
-                     elif col_name.endswith("_summary"):
-                         df_struct = pl.DataFrame(expected_data)
-                         df_struct = df_struct.with_columns(
-                             pl.col("timestamp_at_start").str.strptime(pl.Datetime, format=TS_FORMAT),
-                             pl.col("timestamp_at_end").str.strptime(pl.Datetime, format=TS_FORMAT),
-                         )
-                         want = df_struct.select(
-                             pl.struct(*[col for col in df_struct.columns]).alias(col_name)
-                         )
-                     else:
-                         want = pl.DataFrame({col_name: expected_data}).with_columns(
-                             *[
-                                 pl.col(col_name).cast(PRED_CNT_TYPE)
-                                 if col_name != LAST_EVENT_INDEX_COLUMN
-                                 else pl.col(col_name).cast(EVENT_INDEX_TYPE)
-                             ]
-                         )
-                     got = got_df.select(col_name)
-                     assert_df_equal(want, got, f"Data mismatch for task '{task_name}', column '{col_name}'")
-
-         except AssertionError as e:
-             print(f"Failed on task '{task_name}'")
-             print(f"stderr:\n{full_stderr}")
-             print(f"stdout:\n{full_stdout}")
-             raise e
+     cli_test(
+         input_files={"sample_data": PREDICATES_CSV},
+         task_configs=TASKS_CFGS,
+         want_outputs_by_task=EXPECTED_OUTPUT,
+         data_standard="direct",
+     )
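The deleted driver logic above now lives behind `cli_test` in `tests/utils.py`, one of the changed files not rendered on this page. The sketch below is a plausible reconstruction of such a helper from the removed inline code, not the actual implementation; `run_command`'s hydra-style `key=value` invocation and every other name not present in the old test are assumptions:

```python
import subprocess
import tempfile
from pathlib import Path

import polars as pl
from polars.testing import assert_frame_equal


def run_command(script, hydra_kwargs, test_name):
    """Assumed shape of the existing helper: shell out and return (stderr, stdout)."""
    cmd = [script] + [f"{k}={v}" for k, v in hydra_kwargs.items()]
    out = subprocess.run(cmd, capture_output=True, text=True, check=False)
    assert out.returncode == 0, f"{test_name} failed:\n{out.stderr}"
    return out.stderr, out.stdout


def cli_test(input_files, task_configs, want_outputs_by_task, data_standard):
    """Write inputs to a temp dir, run aces-cli once per task, and compare outputs."""
    with tempfile.TemporaryDirectory() as d:
        data_dir = Path(d) / "sample_data"
        configs_dir = Path(d) / "sample_configs"
        output_dir = Path(d) / "sample_output"
        for directory in (data_dir, configs_dir, output_dir):
            directory.mkdir()

        for name, csv_text in input_files.items():
            (data_dir / f"{name}.csv").write_text(csv_text.strip())

        for task_name, task_cfg in task_configs.items():
            (configs_dir / f"{task_name}.yaml").write_text(task_cfg)
            output_path = output_dir / f"{task_name}.parquet"
            run_command(
                "aces-cli",
                {
                    "data.path": str((data_dir / "sample_data.csv").resolve()),
                    "data.standard": data_standard,
                    "cohort_dir": str(configs_dir.resolve()),
                    "cohort_name": task_name,
                    "output_filepath": str(output_path.resolve()),
                    "hydra.verbose": True,
                },
                task_name,
            )
            assert output_path.is_file(), f"Expected {output_path} to exist."
            got_df = pl.read_parquet(output_path)
            assert_frame_equal(want_outputs_by_task[task_name], got_df)
```

Centralizing this loop means each end-to-end test now supplies only its inputs and expected frames, which is what shrinks `test_e2e.py` by roughly 140 lines.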
