diff --git a/tests/test_strategies.py b/tests/test_strategies.py index c536963b..13922dec 100644 --- a/tests/test_strategies.py +++ b/tests/test_strategies.py @@ -83,6 +83,7 @@ def variable_integer_strategy(draw, name=None, label=None, units=None, covariate ) value_type = ValueType.INTEGER dtype = int + value_range = draw( st.one_of( st.none(), @@ -92,20 +93,10 @@ def variable_integer_strategy(draw, name=None, label=None, units=None, covariate ) ) if value_range is None: - allowed_values = draw( - st.one_of(st.none(), st.lists(st.integers(), unique=True, min_size=1)) - ) + allowed_values = draw(st.one_of(st.none(), st.sets(st.integers(), min_size=1))) else: - allowed_values = draw( - st.one_of( - st.none(), - st.lists( - st.integers(min_value=value_range[0], max_value=value_range[1]), - unique=True, - min_size=1, - ), - ) - ) + allowed_values = None + rescale = draw( st.one_of( st.just(1), @@ -146,25 +137,9 @@ def variable_real_strategy(draw, name=None, label=None, units=None, covariate=No ) if value_range is None: - allowed_values = draw( - st.one_of(st.none(), st.lists(range_strategy, unique=True, min_size=1)) - ) + allowed_values = draw(st.one_of(st.none(), st.sets(range_strategy, min_size=1))) else: - allowed_values = draw( - st.one_of( - st.none(), - st.lists( - st.floats( - min_value=value_range[0], - max_value=value_range[1], - allow_nan=False, - allow_subnormal=False, - ), - unique=True, - min_size=1, - ), - ) - ) + allowed_values = None rescale = draw(st.one_of(st.just(1), range_strategy)) return Variable( name=name, @@ -256,32 +231,35 @@ def variable_class_strategy(draw, name=None, label=None, units=None, covariate=N ) +VARIABLE_STRATEGIES = ( + variable_boolean_strategy, + variable_integer_strategy, + variable_probability_strategy, + variable_sigmoid_strategy, + variable_real_strategy, + variable_class_strategy, +) + + @st.composite -def variable_strategy(draw, value_type: Optional[ValueType] = None, **kwargs): +def variable_strategy( + draw, elements=VARIABLE_STRATEGIES, value_type: Optional[ValueType] = None, **kwargs +): if value_type is None: - return draw( - st.one_of( - variable_boolean_strategy(**kwargs), - variable_integer_strategy(**kwargs), - variable_real_strategy(**kwargs), - variable_probability_strategy(**kwargs), - variable_sigmoid_strategy(**kwargs), - variable_class_strategy(**kwargs), - ) - ) + strategy = draw(st.sampled_from(elements)) + else: - return draw( - { - ValueType.BOOLEAN: variable_boolean_strategy, - ValueType.INTEGER: variable_integer_strategy, - ValueType.REAL: variable_real_strategy, - ValueType.SIGMOID: variable_sigmoid_strategy, - ValueType.PROBABILITY: variable_probability_strategy, - # ValueType.PROBABILITY_SAMPLE: variable_PROBABILITY_SAMPLE_strategy, - # ValueType.PROBABILITY_DISTRIBUTION: variable_PROBABILITY_DISTRIBUTION_strategy, - ValueType.CLASS: variable_class_strategy, - }[value_type](**kwargs) - ) + strategy = { + ValueType.BOOLEAN: variable_boolean_strategy, + ValueType.INTEGER: variable_integer_strategy, + ValueType.REAL: variable_real_strategy, + ValueType.SIGMOID: variable_sigmoid_strategy, + ValueType.PROBABILITY: variable_probability_strategy, + # ValueType.PROBABILITY_SAMPLE: variable_PROBABILITY_SAMPLE_strategy, + # ValueType.PROBABILITY_DISTRIBUTION: variable_PROBABILITY_DISTRIBUTION_strategy, + ValueType.CLASS: variable_class_strategy, + }[value_type] + return draw(strategy(**kwargs)) @given(variable_strategy()) @@ -292,6 +270,8 @@ def test_variable_strategy_creation(o): @st.composite def variablecollection_strategy( draw, + elements=VARIABLE_STRATEGIES, + value_type: Optional[ValueType] = None, max_ivs=5, max_dvs=1, max_covariates=2, @@ -317,13 +297,28 @@ def variablecollection_strategy( ) ) independent_variables = [ - draw(variable_strategy(name=names.pop(), **kwargs)) for _ in range(n_ivs) + draw( + variable_strategy( + name=names.pop(), value_type=value_type, elements=elements, **kwargs + ) + ) + for _ in range(n_ivs) ] dependent_variables = [ - draw(variable_strategy(name=names.pop(), **kwargs)) for _ in range(n_dvs) + draw( + variable_strategy( + name=names.pop(), value_type=value_type, elements=elements, **kwargs + ) + ) + for _ in range(n_dvs) ] covariates = [ - draw(variable_strategy(name=names.pop(), **kwargs)) for _ in range(n_covariates) + draw( + variable_strategy( + name=names.pop(), value_type=value_type, elements=elements, **kwargs + ) + ) + for _ in range(n_covariates) ] vc = VariableCollection( @@ -343,17 +338,19 @@ def test_variablecollection_strategy_creation(o): def dataframe_strategy( draw, variables: Optional[Sequence[Variable]] = None, + value_type: Optional[ValueType] = None, ): if variables is None: - variable_collection = draw(variablecollection_strategy()) + variable_collection = draw(variablecollection_strategy(value_type=value_type)) variables = ( variable_collection.independent_variables + variable_collection.dependent_variables + variable_collection.covariates ) + df: pd.DataFrame = draw( st_pd.data_frames( - columns=[st_pd.column(name=v.name, dtype=v.data_type) for v in variables], + columns=[st_pd.column(dtype=v.data_type) for v in variables], ) )