diff --git a/tests/test_strategies.py b/tests/test_strategies.py index 38503755..3f7350d6 100644 --- a/tests/test_strategies.py +++ b/tests/test_strategies.py @@ -43,19 +43,22 @@ FLOAT_STRATEGIES = st.one_of( st.integers(), - st.floats(min_value=0, max_value=1, allow_nan=False, allow_subnormal=False), st.floats(allow_infinity=False, allow_nan=False, allow_subnormal=False), st.floats(allow_infinity=True, allow_nan=False, allow_subnormal=False), st.floats(allow_infinity=True, allow_nan=False, allow_subnormal=True), ) +PROBABILITY_STRATEGIES = st.floats( + min_value=0, max_value=1, allow_nan=False, allow_subnormal=False +) + VALUE_TYPE_VALUE_STRATEGY_MAPPING = { ValueType.BOOLEAN: st.booleans(), ValueType.INTEGER: st.integers(), ValueType.REAL: FLOAT_STRATEGIES, ValueType.SIGMOID: FLOAT_STRATEGIES, - ValueType.PROBABILITY: FLOAT_STRATEGIES, - ValueType.PROBABILITY_SAMPLE: FLOAT_STRATEGIES, + ValueType.PROBABILITY: PROBABILITY_STRATEGIES, + ValueType.PROBABILITY_SAMPLE: PROBABILITY_STRATEGIES, ValueType.PROBABILITY_DISTRIBUTION: FLOAT_STRATEGIES, ValueType.CLASS: st.text(), } @@ -88,26 +91,41 @@ def variable_strategy( variable_label = draw(st.text(max_size=variable_label_max_length)) units = draw(st.text(max_size=units_max_length)) is_covariate = draw(st.booleans()) - type = draw(st.sampled_from(ValueType)) - dtype = VALUE_TYPE_DTYPE_MAPPING[type] - value_strategy = VALUE_TYPE_VALUE_STRATEGY_MAPPING[type] - - value_range = draw( - st.one_of( - st.none(), - st.tuples(value_strategy, value_strategy).map(sorted), + + value_type = draw(st.sampled_from(ValueType)) + + dtype = VALUE_TYPE_DTYPE_MAPPING[value_type] + value_strategy = VALUE_TYPE_VALUE_STRATEGY_MAPPING[value_type] + + if value_type is ValueType.BOOLEAN: + allowed_values = [True, False] + value_range = None + rescale = 1 + elif value_type in { + ValueType.PROBABILITY, + ValueType.PROBABILITY_SAMPLE, + ValueType.PROBABILITY_DISTRIBUTION, + }: + value_range = (0, 1) + allowed_values = None + rescale = 1 + else: + value_range = draw( + st.one_of( + st.none(), + st.tuples(value_strategy, value_strategy).map(sorted), + ) ) - ) - allowed_values = draw( - st.one_of(st.none(), st.lists(value_strategy, unique=True, min_size=1)) - ) - rescale = draw(st.one_of(st.just(1), value_strategy)) + allowed_values = draw( + st.one_of(st.none(), st.lists(value_strategy, unique=True, min_size=1)) + ) + rescale = draw(st.one_of(st.just(1), value_strategy)) v = Variable( name=name, variable_label=variable_label, units=units, - type=type, + type=value_type, is_covariate=is_covariate, value_range=value_range, allowed_values=allowed_values,