Skip to content

Commit

Permalink
test: simplify generation of valid variables
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Nov 22, 2023
1 parent 901253f commit 8e8fbb1
Showing 1 changed file with 35 additions and 17 deletions.
52 changes: 35 additions & 17 deletions tests/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,22 @@

FLOAT_STRATEGIES = st.one_of(
st.integers(),
st.floats(min_value=0, max_value=1, allow_nan=False, allow_subnormal=False),
st.floats(allow_infinity=False, allow_nan=False, allow_subnormal=False),
st.floats(allow_infinity=True, allow_nan=False, allow_subnormal=False),
st.floats(allow_infinity=True, allow_nan=False, allow_subnormal=True),
)

PROBABILITY_STRATEGIES = st.floats(
min_value=0, max_value=1, allow_nan=False, allow_subnormal=False
)

VALUE_TYPE_VALUE_STRATEGY_MAPPING = {
ValueType.BOOLEAN: st.booleans(),
ValueType.INTEGER: st.integers(),
ValueType.REAL: FLOAT_STRATEGIES,
ValueType.SIGMOID: FLOAT_STRATEGIES,
ValueType.PROBABILITY: FLOAT_STRATEGIES,
ValueType.PROBABILITY_SAMPLE: FLOAT_STRATEGIES,
ValueType.PROBABILITY: PROBABILITY_STRATEGIES,
ValueType.PROBABILITY_SAMPLE: PROBABILITY_STRATEGIES,
ValueType.PROBABILITY_DISTRIBUTION: FLOAT_STRATEGIES,
ValueType.CLASS: st.text(),
}
Expand Down Expand Up @@ -88,26 +91,41 @@ def variable_strategy(
variable_label = draw(st.text(max_size=variable_label_max_length))
units = draw(st.text(max_size=units_max_length))
is_covariate = draw(st.booleans())
type = draw(st.sampled_from(ValueType))
dtype = VALUE_TYPE_DTYPE_MAPPING[type]
value_strategy = VALUE_TYPE_VALUE_STRATEGY_MAPPING[type]

value_range = draw(
st.one_of(
st.none(),
st.tuples(value_strategy, value_strategy).map(sorted),

value_type = draw(st.sampled_from(ValueType))

dtype = VALUE_TYPE_DTYPE_MAPPING[value_type]
value_strategy = VALUE_TYPE_VALUE_STRATEGY_MAPPING[value_type]

if value_type is ValueType.BOOLEAN:
allowed_values = [True, False]
value_range = None
rescale = 1
elif value_type in {
ValueType.PROBABILITY,
ValueType.PROBABILITY_SAMPLE,
ValueType.PROBABILITY_DISTRIBUTION,
}:
value_range = (0, 1)
allowed_values = None
rescale = 1
else:
value_range = draw(
st.one_of(
st.none(),
st.tuples(value_strategy, value_strategy).map(sorted),
)
)
)
allowed_values = draw(
st.one_of(st.none(), st.lists(value_strategy, unique=True, min_size=1))
)
rescale = draw(st.one_of(st.just(1), value_strategy))
allowed_values = draw(
st.one_of(st.none(), st.lists(value_strategy, unique=True, min_size=1))
)
rescale = draw(st.one_of(st.just(1), value_strategy))

v = Variable(
name=name,
variable_label=variable_label,
units=units,
type=type,
type=value_type,
is_covariate=is_covariate,
value_range=value_range,
allowed_values=allowed_values,
Expand Down

0 comments on commit 8e8fbb1

Please sign in to comment.