From 70ff947b74a8df4cb2a2a2773793bc7c8990a4da Mon Sep 17 00:00:00 2001 From: Christopher Bunn Date: Wed, 18 Jan 2023 14:58:32 -0500 Subject: [PATCH] Fixed `set_period()` not updating decomposer parameters (#3932) * Initial commit * Update release notes * Fix docstring * Added test case back --- docs/source/release_notes.rst | 1 + evalml/pipelines/components/component_base.py | 11 +++++++++++ .../transformers/preprocessing/decomposer.py | 2 +- .../decomposer_tests/test_decomposer.py | 3 +-- evalml/tests/component_tests/test_components.py | 17 +++++++++++++++++ 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index fc94a266c9..99f27ecacd 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -3,6 +3,7 @@ Release Notes **Future Releases** * Enhancements * Fixes + * Fixed `set_period()` not updating decomposer parameters :pr:`3932` * Changes * Updated ``PolynomialDecomposer`` to work with sktime v0.15.1 :pr:`3930` * Pinned `category-encoders`` to 2.5.1.post0 :pr:`3933`` diff --git a/evalml/pipelines/components/component_base.py b/evalml/pipelines/components/component_base.py index 6a06065166..935dc09967 100644 --- a/evalml/pipelines/components/component_base.py +++ b/evalml/pipelines/components/component_base.py @@ -227,3 +227,14 @@ def __repr__(self): [f"{key}={safe_repr(value)}" for key, value in self.parameters.items()], ) return f"{(type(self).__name__)}({parameters_repr})" + + def update_parameters(self, update_dict, reset_fit=True): + """Updates the parameter dictionary of the component. + + Args: + update_dict (dict): A dict of parameters to update. + reset_fit (bool, optional): If True, will set `_is_fitted` to False. + """ + self._parameters.update(update_dict) + if reset_fit: + self._is_fitted = False diff --git a/evalml/pipelines/components/transformers/preprocessing/decomposer.py b/evalml/pipelines/components/transformers/preprocessing/decomposer.py index 2fc01f52d8..8466d01bc3 100644 --- a/evalml/pipelines/components/transformers/preprocessing/decomposer.py +++ b/evalml/pipelines/components/transformers/preprocessing/decomposer.py @@ -205,7 +205,7 @@ def set_period(self, X: pd.DataFrame, y: pd.Series): """ self.period = self.determine_periodicity(X, y) - self.parameters.update({"period": self.period}) + self.update_parameters({"period": self.period}) def _check_oos_past(self, y): """Function to check whether provided target data is out-of-sample and in the past.""" diff --git a/evalml/tests/component_tests/decomposer_tests/test_decomposer.py b/evalml/tests/component_tests/decomposer_tests/test_decomposer.py index 1edbf903d7..e8884e69e8 100644 --- a/evalml/tests/component_tests/decomposer_tests/test_decomposer.py +++ b/evalml/tests/component_tests/decomposer_tests/test_decomposer.py @@ -436,8 +436,7 @@ def test_decomposer_set_period(decomposer_child_class, period, generate_seasonal dec.set_period(X, y) assert 0.95 * period <= dec.period <= 1.05 * period - # TODO: Fix this with https://github.com/alteryx/evalml/issues/3771 - # assert dec.parameters["period"] == dec.period + assert dec.parameters["period"] == dec.period @pytest.mark.parametrize( diff --git a/evalml/tests/component_tests/test_components.py b/evalml/tests/component_tests/test_components.py index 164c9f07da..0f9650c020 100644 --- a/evalml/tests/component_tests/test_components.py +++ b/evalml/tests/component_tests/test_components.py @@ -570,6 +570,23 @@ def test_describe_component(): } +def test_update_parameters(X_y_binary): + X, y = X_y_binary + new_val = "New val" + + cmp = MockFitComponent() + cmp.fit(X, y) + cmp.update_parameters({"param_a": new_val}) + assert cmp.parameters["param_a"] == new_val + assert cmp._is_fitted is False + + cmp = MockFitComponent() + cmp.fit(X, y) + cmp.update_parameters({"param_b": new_val}, reset_fit=False) + assert cmp.parameters["param_b"] == new_val + assert cmp._is_fitted is True + + def test_missing_attributes(X_y_binary): class MockComponentName(ComponentBase): pass