diff --git a/tests/integration/diag_scripts/mlr/test_custom_sklearn_functions.py b/tests/integration/diag_scripts/mlr/test_custom_sklearn_functions.py index f14ccb3e2e..b058465188 100644 --- a/tests/integration/diag_scripts/mlr/test_custom_sklearn_functions.py +++ b/tests/integration/diag_scripts/mlr/test_custom_sklearn_functions.py @@ -459,10 +459,18 @@ def _more_tags(self): return {"allow_nan": True} +class MockBaseEstimator: + """Estimator with ``_get_tags``.""" + + def _get_tags(self): + """Return tags.""" + return _DEFAULT_TAGS + + @pytest.mark.parametrize( 'estimator,err_msg', [ - (BaseEstimator(), 'The key xxx is not defined in _get_tags'), + (MockBaseEstimator(), 'The key xxx is not defined in _get_tags'), (NoTagsEstimator(), 'The key xxx is not defined in _DEFAULT_TAGS'), ], ) @@ -480,9 +488,8 @@ def test_safe_tags_error(estimator, err_msg): (NoTagsEstimator(), 'allow_nan', _DEFAULT_TAGS['allow_nan']), (MoreTagsEstimator(), None, {**_DEFAULT_TAGS, **{'allow_nan': True}}), (MoreTagsEstimator(), 'allow_nan', True), - (BaseEstimator(), None, _DEFAULT_TAGS), - (BaseEstimator(), 'allow_nan', _DEFAULT_TAGS['allow_nan']), - (BaseEstimator(), 'allow_nan', _DEFAULT_TAGS['allow_nan']), + (MockBaseEstimator(), None, _DEFAULT_TAGS), + (MockBaseEstimator(), 'allow_nan', _DEFAULT_TAGS['allow_nan']), ], ) def test_safe_tags_no_get_tags(estimator, key, expected_results):