Skip to content

Commit

Permalink
Merge pull request #2189 from opensafely-core/DRMacIver/fix-failing-n…
Browse files Browse the repository at this point in the history
…extgen-assertion

Fix occasional assertion failure in nextgen dummy data
  • Loading branch information
DRMacIver authored Oct 25, 2024
2 parents 930bdd7 + c0eec5e commit f4ffffd
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 9 deletions.
7 changes: 4 additions & 3 deletions ehrql/dummy_data_nextgen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,10 @@ def get_random_value(self, column_info):
# NOTE: Currently this code only runs when the column is date_of_birth, so
# the FirstOfMonth constraint is always present here. The minimum only needs
# moving forward when it is not already the first day of a month.
if column_info.get_constraint(
Constraint.FirstOfMonth
): # pragma: no branch
if (
column_info.get_constraint(Constraint.FirstOfMonth)
and minimum.day != 1
):
if minimum.month == 12:
minimum = minimum.replace(year=minimum.year + 1, month=1, day=1)
else:
Expand Down
62 changes: 62 additions & 0 deletions ehrql/query_model/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ class AggregatedSeries(OneRowPerPatientSeries): ...
class Value(OneRowPerPatientSeries[T]):
value: T

def _repr_pretty_(self, p, cycle):
    """Pretty-print this node as its wrapped value.

    `p` is the pretty printer driving the output; `cycle` is accepted to
    satisfy the hook's signature but is not used.
    """
    wrapped = self.value
    p.pretty(wrapped)

def __post_init__(self):
super().__post_init__()
# Because we need to be strict about equality (see `__eq__()` below) we can only
Expand Down Expand Up @@ -192,11 +195,17 @@ class SelectTable(ManyRowsPerPatientFrame):
name: str
schema: TableSchema

def _repr_pretty_(self, p, cycle):
    """Pretty-print this table node as its bare table name.

    The name is written with `p.text()` rather than `p.pretty()`:
    pretty-printing a `str` emits its repr (with quotes), which would make a
    column selected from this table render as `'events'.col` instead of
    `events.col`. Writing the raw text also matches how
    `SelectPatientTable` renders its name.

    `cycle` is accepted to satisfy the hook's signature but is not used.
    """
    p.text(self.name)


class SelectPatientTable(OneRowPerPatientFrame):
    # Frame node identified by a table name and its schema.
    name: str
    schema: TableSchema

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node as its bare (unquoted) table name.

        `cycle` is accepted to satisfy the hook's signature but is not used.
        """
        table_name = self.name
        p.text(table_name)


class InlinePatientTable(OneRowPerPatientFrame):
# `rows` is an iterable of tuples specifying the data for the table in the form:
Expand Down Expand Up @@ -224,6 +233,10 @@ class SelectColumn(Series):
source: Frame
name: str

def _repr_pretty_(self, p, cycle):
    """Pretty-print this node as `<source>.<name>`.

    The source frame is rendered by the printer first, then the column name
    is appended as attribute-style text. `cycle` is not used.
    """
    p.pretty(self.source)
    suffix = "." + self.name
    p.text(suffix)


class Filter(ManyRowsPerPatientFrame):
source: ManyRowsPerPatientFrame
Expand Down Expand Up @@ -293,38 +306,87 @@ class EQ(Series[bool]):
lhs: Series[T]
rhs: Series[T]

def _repr_pretty_(self, p, cycle):
    """Pretty-print this node in infix form: `<lhs> == <rhs>`.

    `cycle` is accepted to satisfy the hook's signature but is not used.
    """
    left, right = self.lhs, self.rhs
    p.pretty(left)
    p.text(" == ")
    p.pretty(right)

class NE(Series[bool]):
    # Boolean series comparing two operand series with `!=`.
    lhs: Series[T]
    rhs: Series[T]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node in infix form: `<lhs> != <rhs>`.

        `cycle` is accepted to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.pretty(left)
        p.text(" != ")
        p.pretty(right)

class LT(Series[bool]):
    # Boolean series comparing two comparable operand series with `<`.
    lhs: Series[Comparable]
    rhs: Series[Comparable]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node in infix form: `<lhs> < <rhs>`.

        `cycle` is accepted to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.pretty(left)
        p.text(" < ")
        p.pretty(right)

class LE(Series[bool]):
    # Boolean series comparing two comparable operand series with `<=`.
    lhs: Series[Comparable]
    rhs: Series[Comparable]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node in infix form: `<lhs> <= <rhs>`.

        `cycle` is accepted to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.pretty(left)
        p.text(" <= ")
        p.pretty(right)

class GT(Series[bool]):
    # Boolean series comparing two comparable operand series with `>`.
    lhs: Series[Comparable]
    rhs: Series[Comparable]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node in infix form: `<lhs> > <rhs>`.

        `cycle` is accepted to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.pretty(left)
        p.text(" > ")
        p.pretty(right)

class GE(Series[bool]):
    # Boolean series comparing two comparable operand series with `>=`.
    lhs: Series[Comparable]
    rhs: Series[Comparable]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node in infix form: `<lhs> >= <rhs>`.

        `cycle` is accepted to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.pretty(left)
        p.text(" >= ")
        p.pretty(right)

# Boolean
class And(Series[bool]):
    # Boolean series combining two boolean operand series with `&`.
    lhs: Series[bool]
    rhs: Series[bool]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node as `(<lhs>) & (<rhs>)`.

        Both operands are always wrapped in parentheses. `cycle` is accepted
        to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.text("(")
        p.pretty(left)
        p.text(") & (")
        p.pretty(right)
        p.text(")")

class Or(Series[bool]):
    # Boolean series combining two boolean operand series with `|`.
    lhs: Series[bool]
    rhs: Series[bool]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node as `(<lhs>) | (<rhs>)`.

        Both operands are always wrapped in parentheses. `cycle` is accepted
        to satisfy the hook's signature but is not used.
        """
        left, right = self.lhs, self.rhs
        p.text("(")
        p.pretty(left)
        p.text(") | (")
        p.pretty(right)
        p.text(")")

class Not(Series[bool]):
    # Boolean series negating a single boolean operand series.
    source: Series[bool]

    def _repr_pretty_(self, p, cycle):
        """Pretty-print this node as `~(<source>)`.

        The operand is always wrapped in parentheses. `cycle` is accepted to
        satisfy the hook's signature but is not used.
        """
        operand = self.source
        p.text("~(")
        p.pretty(operand)
        p.text(")")

# Null handling
class IsNull(Series[bool]):
source: Series[Any]
Expand Down
6 changes: 5 additions & 1 deletion tests/generative/test_query_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import hypothesis as hyp
import hypothesis.strategies as st
import pytest
from hypothesis.vendor.pretty import _singleton_pprinters
from hypothesis.vendor.pretty import _singleton_pprinters, pretty

from ehrql.dummy_data import DummyDataGenerator
from ehrql.query_model.introspection import all_unique_nodes
Expand Down Expand Up @@ -116,6 +116,7 @@ class EnabledTests(Enum):
dummy_data = auto()
main_query = auto()
all_population = auto()
pretty_printing = auto()


if TEST_NAMES_TO_RUN := set(
Expand Down Expand Up @@ -150,6 +151,9 @@ def test_query_model(
run_dummy_data_test(population, variable)
if EnabledTests.main_query in test_types:
run_test(query_engines, data, population, variable, recorder)
if EnabledTests.pretty_printing in test_types:
pretty(population)
pretty(data)

if EnabledTests.all_population in test_types:
# We run the test again using a simplified population definition which includes all
Expand Down
18 changes: 13 additions & 5 deletions tests/unit/dummy_data_nextgen/test_specific_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import operator
from datetime import date
from unittest import mock

Expand Down Expand Up @@ -164,7 +165,7 @@ def birthday_range_query(draw):
# results.

reasonable_dates = st.dates(min_value=date(1900, 1, 1), max_value=date(2024, 9, 1))
valid_date = draw(reasonable_dates).replace(month=1)
valid_date = draw(reasonable_dates).replace(day=1)

query_endpoints = draw(st.lists(reasonable_dates, min_size=1))
query_components = []
Expand All @@ -179,21 +180,28 @@ def birthday_range_query(draw):
allow_equal = True
if endpoint >= valid_date:
if allow_equal:
query_components.append(patients.date_of_birth <= endpoint)
op = operator.le
else:
query_components.append(patients.date_of_birth < endpoint)
op = operator.lt
else:
if allow_equal:
query_components.append(patients.date_of_birth >= endpoint)
op = operator.ge
else:
query_components.append(patients.date_of_birth > endpoint)
op = operator.gt
assert op(valid_date, endpoint)
query_components.append(op(patients.date_of_birth, endpoint))
while len(query_components) > 1:
q = query_components.pop()
i = draw(st.integers(0, len(query_components) - 1))
query_components[i] &= q
return query_components[0]


@example(
query=(patients.date_of_birth >= date(2000, 1, 1))
& (patients.date_of_birth < date(2000, 1, 2)),
target_size=1,
)
@example(query=patients.date_of_birth < date(1900, 12, 31), target_size=1000)
@example(query=patients.date_of_birth >= date(1900, 1, 2), target_size=1000)
@example(query=patients.date_of_birth >= date(2000, 12, 1), target_size=1000)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_query_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def test_incompatible_duration_operations(expr):
"__radd__",
"__rsub__",
"_cast",
"_repr_pretty_",
},
)

Expand Down

0 comments on commit f4ffffd

Please sign in to comment.