Skip to content

Commit

Permalink
test: disallow both value range and allowed values.
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Nov 22, 2023
1 parent 496ed93 commit f5240f0
Showing 1 changed file with 56 additions and 59 deletions.
115 changes: 56 additions & 59 deletions tests/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def variable_integer_strategy(draw, name=None, label=None, units=None, covariate
)
value_type = ValueType.INTEGER
dtype = int

value_range = draw(
st.one_of(
st.none(),
Expand All @@ -92,20 +93,10 @@ def variable_integer_strategy(draw, name=None, label=None, units=None, covariate
)
)
if value_range is None:
allowed_values = draw(
st.one_of(st.none(), st.lists(st.integers(), unique=True, min_size=1))
)
allowed_values = draw(st.one_of(st.none(), st.sets(st.integers(), min_size=1)))
else:
allowed_values = draw(
st.one_of(
st.none(),
st.lists(
st.integers(min_value=value_range[0], max_value=value_range[1]),
unique=True,
min_size=1,
),
)
)
allowed_values = None

rescale = draw(
st.one_of(
st.just(1),
Expand Down Expand Up @@ -146,25 +137,9 @@ def variable_real_strategy(draw, name=None, label=None, units=None, covariate=No
)

if value_range is None:
allowed_values = draw(
st.one_of(st.none(), st.lists(range_strategy, unique=True, min_size=1))
)
allowed_values = draw(st.one_of(st.none(), st.sets(range_strategy, min_size=1)))
else:
allowed_values = draw(
st.one_of(
st.none(),
st.lists(
st.floats(
min_value=value_range[0],
max_value=value_range[1],
allow_nan=False,
allow_subnormal=False,
),
unique=True,
min_size=1,
),
)
)
allowed_values = None
rescale = draw(st.one_of(st.just(1), range_strategy))
return Variable(
name=name,
Expand Down Expand Up @@ -256,32 +231,35 @@ def variable_class_strategy(draw, name=None, label=None, units=None, covariate=N
)


VARIABLE_STRATEGIES = (
variable_boolean_strategy,
variable_integer_strategy,
variable_probability_strategy,
variable_sigmoid_strategy,
variable_real_strategy,
variable_class_strategy,
)


@st.composite
def variable_strategy(draw, value_type: Optional[ValueType] = None, **kwargs):
def variable_strategy(
draw, elements=VARIABLE_STRATEGIES, value_type: Optional[ValueType] = None, **kwargs
):
if value_type is None:
return draw(
st.one_of(
variable_boolean_strategy(**kwargs),
variable_integer_strategy(**kwargs),
variable_real_strategy(**kwargs),
variable_probability_strategy(**kwargs),
variable_sigmoid_strategy(**kwargs),
variable_class_strategy(**kwargs),
)
)
strategy = draw(st.sampled_from(elements))

else:
return draw(
{
ValueType.BOOLEAN: variable_boolean_strategy,
ValueType.INTEGER: variable_integer_strategy,
ValueType.REAL: variable_real_strategy,
ValueType.SIGMOID: variable_sigmoid_strategy,
ValueType.PROBABILITY: variable_probability_strategy,
# ValueType.PROBABILITY_SAMPLE: variable_PROBABILITY_SAMPLE_strategy,
# ValueType.PROBABILITY_DISTRIBUTION: variable_PROBABILITY_DISTRIBUTION_strategy,
ValueType.CLASS: variable_class_strategy,
}[value_type](**kwargs)
)
strategy = {
ValueType.BOOLEAN: variable_boolean_strategy,
ValueType.INTEGER: variable_integer_strategy,
ValueType.REAL: variable_real_strategy,
ValueType.SIGMOID: variable_sigmoid_strategy,
ValueType.PROBABILITY: variable_probability_strategy,
# ValueType.PROBABILITY_SAMPLE: variable_PROBABILITY_SAMPLE_strategy,
# ValueType.PROBABILITY_DISTRIBUTION: variable_PROBABILITY_DISTRIBUTION_strategy,
ValueType.CLASS: variable_class_strategy,
}[value_type]
return draw(strategy(**kwargs))


@given(variable_strategy())
Expand All @@ -292,6 +270,8 @@ def test_variable_strategy_creation(o):
@st.composite
def variablecollection_strategy(
draw,
elements=VARIABLE_STRATEGIES,
value_type: Optional[ValueType] = None,
max_ivs=5,
max_dvs=1,
max_covariates=2,
Expand All @@ -317,13 +297,28 @@ def variablecollection_strategy(
)
)
independent_variables = [
draw(variable_strategy(name=names.pop(), **kwargs)) for _ in range(n_ivs)
draw(
variable_strategy(
name=names.pop(), value_type=value_type, elements=elements, **kwargs
)
)
for _ in range(n_ivs)
]
dependent_variables = [
draw(variable_strategy(name=names.pop(), **kwargs)) for _ in range(n_dvs)
draw(
variable_strategy(
name=names.pop(), value_type=value_type, elements=elements, **kwargs
)
)
for _ in range(n_dvs)
]
covariates = [
draw(variable_strategy(name=names.pop(), **kwargs)) for _ in range(n_covariates)
draw(
variable_strategy(
name=names.pop(), value_type=value_type, elements=elements, **kwargs
)
)
for _ in range(n_covariates)
]

vc = VariableCollection(
Expand All @@ -343,17 +338,19 @@ def test_variablecollection_strategy_creation(o):
def dataframe_strategy(
draw,
variables: Optional[Sequence[Variable]] = None,
value_type: Optional[ValueType] = None,
):
if variables is None:
variable_collection = draw(variablecollection_strategy())
variable_collection = draw(variablecollection_strategy(value_type=value_type))
variables = (
variable_collection.independent_variables
+ variable_collection.dependent_variables
+ variable_collection.covariates
)

df: pd.DataFrame = draw(
st_pd.data_frames(
columns=[st_pd.column(name=v.name, dtype=v.data_type) for v in variables],
columns=[st_pd.column(dtype=v.data_type) for v in variables],
)
)

Expand Down

0 comments on commit f5240f0

Please sign in to comment.