Skip to content

Commit

Permalink
Merge pull request #981 from opensafely-core/t1oo
Browse files Browse the repository at this point in the history
Automatically apply T1OO exclusions
  • Loading branch information
evansd authored Nov 3, 2023
2 parents 7d869e8 + dad5387 commit ff2cb40
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 2 deletions.
1 change: 1 addition & 0 deletions cohortextractor/study_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
self._original_default_expectations = default_expectations or {}
self.set_index_date(index_date)
self.pandas_csv_args = self.get_pandas_csv_args(self.covariate_definitions)

self.database_url = os.environ.get("DATABASE_URL")
self.temporary_database = os.environ.get("TEMP_DATABASE_NAME")
if self.database_url:
Expand Down
55 changes: 53 additions & 2 deletions cohortextractor/tpp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import uuid
from functools import cached_property
from urllib import parse

import pandas
import structlog
Expand Down Expand Up @@ -36,10 +37,14 @@
SLEEP = 4
BACKOFF_FACTOR = 4

T1OO_TABLE = "PatientsWithTypeOneDissent"


class TPPBackend:
_db_connection = None
_current_column_name = None
# TODO: Temporary default to support safe deployment
include_t1oo = True

def __init__(
self,
Expand All @@ -48,18 +53,42 @@ def __init__(
temporary_database=None,
dummy_data=False,
):
self.database_url = database_url
if database_url is not None:
# set self.include_t1oo from the database url
self.database_url = self.modify_dsn(database_url)
else:
self.database_url = database_url

self.covariate_definitions = covariate_definitions
self.temporary_database = temporary_database
self.dummy_data = dummy_data
self.next_temp_table_id = 1
self._therapeutics_table_name = None
self.truncate_sql_logs = False

if self.covariate_definitions:
self.queries = self.get_queries(self.covariate_definitions)
else:
self.queries = []

def modify_dsn(self, dsn):
"""
Removes the `opensafely_include_t1oo` parameter if present and uses it to set
the `include_t1oo` attribute accordingly
"""
parts = parse.urlparse(dsn)
params = parse.parse_qs(parts.query, keep_blank_values=True)
include_t1oo_values = params.pop("opensafely_include_t1oo", [])
if len(include_t1oo_values) == 1:
self.include_t1oo = include_t1oo_values[0].lower() == "true"
elif len(include_t1oo_values) != 0:
raise ValueError(
"`opensafely_include_t1oo` parameter must not be supplied more than once"
)
new_query = parse.urlencode(params, doseq=True)
new_parts = parts._replace(query=new_query)
return parse.urlunparse(new_parts)

def to_file(self, filename):
queries = list(self.queries)
# If we have a temporary database available we write results to a table
Expand Down Expand Up @@ -383,15 +412,37 @@ def get_queries(self, covariate_definitions):
for name in table_queries
if name != "population"
]

wheres = [f'{output_columns["population"]} = 1']

def get_t1oo_exclude_expressions():
# If this query has been explictly flagged as including T1OO patients then
# return unmodified
if self.include_t1oo:
return [], []
# Otherwise we add an extra LEFT OUTER JOIN on the T1OO table and
# WHERE clause which will exclude any patient IDs found in the T1OO table
return (
[
f"LEFT OUTER JOIN {T1OO_TABLE} ON {T1OO_TABLE}.Patient_ID = {patient_id_expr}"
],
[f"{T1OO_TABLE}.Patient_ID IS null"],
)

t100_join, t1oo_where = get_t1oo_exclude_expressions()
joins.extend(t100_join)
joins_str = "\n ".join(joins)
wheres.extend(t1oo_where)
where_str = " AND ".join(wheres)

joined_output_query = f"""
-- Join all columns for final output
SELECT
{output_columns_str}
FROM
{primary_table}
{joins_str}
WHERE {output_columns["population"]} = 1
WHERE {where_str}
"""
all_queries = []
for sql_list in table_queries.values():
Expand Down
59 changes: 59 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
from collections import Counter
from unittest.mock import patch
Expand All @@ -16,6 +17,15 @@
)


@pytest.fixture
def set_database_url_with_t1oo(monkeypatch):
if "TPP_DATABASE_URL" in os.environ:
monkeypatch.setenv(
"DATABASE_URL",
f'{os.environ["TPP_DATABASE_URL"]}?opensafely_include_t1oo=True',
)


@pytest.fixture(name="logger")
def fixture_logger():
"""Modify `capture_logs` to keep reference to `processors` list intact,
Expand Down Expand Up @@ -195,6 +205,55 @@ def test_stats_logging_tpp_backend(logger):
)


def test_stats_logging_tpp_backend_with_t1oo(logger, set_database_url_with_t1oo):
# The query counter is a global at the module level, so it isn't reset between tests
# Find the next position (without incrementing it); this is the start of the test's timing logs
start_counter = timing_log_counter.next

study = StudyDefinition(
population=patients.all(),
)
study.to_dicts()

# initial stats
expected_initial_study_def_logs = [
# output columns include patient_id and population
# tables - Patient table only
# no joins because t1oo are included, so there is no need to join on the
# t1oo table to exclude them
{"output_column_count": 2, "table_count": 1, "table_joins_count": 0},
{"variable_count": 1},
{"variables_using_codelist_count": 0},
]
# timing stats
# logs in tpp_backend during query execution

expected_timing_log_params = [
*_sql_execute_timing_logs(
description="Query for population",
sql="SELECT * INTO #population",
timing_id=start_counter,
),
*_sql_execute_timing_logs(
description=None,
sql="CREATE CLUSTERED INDEX patient_id_ix ON #population (patient_id)",
timing_id=start_counter + 1,
is_truncated=False,
),
*_sql_execute_timing_logs(
description="Join all columns for final output",
sql="#population.patient_id AS [patient_id]",
timing_id=start_counter + 2,
),
]
assert_stats_logs(
logger,
expected_initial_study_def_logs,
expected_timing_log_params,
downloaded=False,
)


@patch("cohortextractor.cohortextractor.preflight_generation_check")
@patch(
"cohortextractor.cohortextractor.list_study_definitions",
Expand Down
96 changes: 96 additions & 0 deletions tests/test_tpp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
Organisation,
Patient,
PatientAddress,
PatientsWithTypeOneDissent,
PotentialCareHomeAddress,
RegistrationHistory,
SGSS_AllTests_Negative,
Expand All @@ -79,6 +80,18 @@ def set_database_url(monkeypatch):
monkeypatch.setenv("DATABASE_URL", os.environ["TPP_DATABASE_URL"])


@pytest.fixture
def set_database_url_with_t1oo(monkeypatch):
def set_db_url(t1oo_value):
if "TPP_DATABASE_URL" in os.environ:
monkeypatch.setenv(
"DATABASE_URL",
f'{os.environ["TPP_DATABASE_URL"]}?opensafely_include_t1oo={t1oo_value}',
)

return set_db_url


def setup_module(module):
make_database()

Expand Down Expand Up @@ -128,10 +141,64 @@ def setup_function(function):
session.query(UKRR).delete()
session.query(Patient).delete()
session.query(BuildProgress).delete()
session.query(PatientsWithTypeOneDissent).delete()

session.commit()


@pytest.mark.parametrize(
"dsn_in,dsn_out,t1oo_status",
[
(
"mssql://user:pass@localhost:4321/db",
"mssql://user:pass@localhost:4321/db",
True,
),
(
"mssql://user:pass@localhost:4321/db?param1=one&param2&param1=three",
"mssql://user:pass@localhost:4321/db?param1=one&param1=three&param2=",
True,
),
(
"mssql://user:pass@localhost:4321/db?opensafely_include_t1oo&param2=two",
"mssql://user:pass@localhost:4321/db?param2=two",
False,
),
(
"mssql://user:pass@localhost:4321/db?opensafely_include_t1oo=false",
"mssql://user:pass@localhost:4321/db",
False,
),
(
"mssql://user:pass@localhost:4321/db?opensafely_include_t1oo=true",
"mssql://user:pass@localhost:4321/db",
True,
),
(
"mssql://user:pass@localhost:4321/db?opensafely_include_t1oo=True",
"mssql://user:pass@localhost:4321/db",
True,
),
],
)
def test_tpp_backend_modify_dsn(dsn_in, dsn_out, t1oo_status):
backend = TPPBackend(database_url=dsn_in, covariate_definitions=None)
assert backend.database_url == dsn_out
assert backend.include_t1oo == t1oo_status


@pytest.mark.parametrize(
"dsn",
[
"mssql://user:pass@localhost:4321/db?opensafely_include_t1oo=false&opensafely_include_t1oo=false",
"mssql://user:pass@localhost:4321/db?opensafely_include_t1oo=false&opensafely_include_t1oo",
],
)
def test_tpp_backend_modify_dsn_rejects_duplicate_params(dsn):
with pytest.raises(ValueError, match="must not be supplied more than once"):
TPPBackend(database_url=dsn, covariate_definitions=None)


@pytest.mark.parametrize("format", ["csv", "csv.gz", "feather", "dta", "dta.gz"])
def test_minimal_study_to_file(tmp_path, format):
session = make_session()
Expand Down Expand Up @@ -181,6 +248,35 @@ def test_minimal_study_with_reserved_keywords():
assert_results(study.to_dicts(), all=["M", "F"], asc=["40", "55"])


@pytest.mark.parametrize(
"flag,expected",
[
("", ["1", "4"]),
("False", ["1", "4"]),
("false", ["1", "4"]),
("1", ["1", "4"]),
("True", ["1", "2", "3", "4"]),
("true", ["1", "2", "3", "4"]),
],
)
def test_minimal_study_with_t1oo_flag(set_database_url_with_t1oo, flag, expected):
set_database_url_with_t1oo(flag)
# Test that type 1 opt-outs are only included if flag is explicitly set to "True"
session = make_session()
patient_1 = Patient(Patient_ID=1, DateOfBirth="1980-01-01", Sex="M")
patient_2 = Patient(Patient_ID=2, DateOfBirth="1965-01-01", Sex="F")
patient_3 = Patient(Patient_ID=3, DateOfBirth="1975-01-01", Sex="F")
patient_4 = Patient(Patient_ID=4, DateOfBirth="1985-01-01", Sex="F")
t1oo_2 = PatientsWithTypeOneDissent(Patient_ID=2)
t1oo_3 = PatientsWithTypeOneDissent(Patient_ID=3)
session.add_all([patient_1, patient_2, patient_3, patient_4, t1oo_2, t1oo_3])
session.commit()
study = StudyDefinition(
population=patients.all(),
)
assert_results(study.to_dicts(), patient_id=expected)


@pytest.mark.parametrize("format", ["csv", "csv.gz", "feather", "dta", "dta.gz"])
def test_study_to_file_with_therapeutic_risk_groups(tmp_path, format):
session = make_session()
Expand Down
8 changes: 8 additions & 0 deletions tests/tpp_backend_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,11 @@ class VmpMapping(Base):

id = Column(String(collation="Latin1_General_CI_AS"))
prev_id = Column(String(collation="Latin1_General_CI_AS"))


class PatientsWithTypeOneDissent(Base):
__tablename__ = "PatientsWithTypeOneDissent"
# fake pk to satisfy the ORM
# Patient_ID might be the primary key, TBC
pk = Column(Integer, primary_key=True)
Patient_ID = Column(types.BIGINT)

0 comments on commit ff2cb40

Please sign in to comment.