Skip to content

Commit

Permalink
Fixed set_period() not updating decomposer parameters (#3932)
Browse files Browse the repository at this point in the history
* Initial commit

* Update release notes

* Fix docstring

* Added test case back
  • Loading branch information
christopherbunn authored Jan 18, 2023
1 parent 16a1c86 commit 70ff947
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Release Notes
**Future Releases**
* Enhancements
* Fixes
* Fixed `set_period()` not updating decomposer parameters :pr:`3932`
* Changes
* Updated ``PolynomialDecomposer`` to work with sktime v0.15.1 :pr:`3930`
* Pinned `category-encoders`` to 2.5.1.post0 :pr:`3933``
Expand Down
11 changes: 11 additions & 0 deletions evalml/pipelines/components/component_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,14 @@ def __repr__(self):
[f"{key}={safe_repr(value)}" for key, value in self.parameters.items()],
)
return f"{(type(self).__name__)}({parameters_repr})"

def update_parameters(self, update_dict, reset_fit=True):
"""Updates the parameter dictionary of the component.
Args:
update_dict (dict): A dict of parameters to update.
reset_fit (bool, optional): If True, will set `_is_fitted` to False.
"""
self._parameters.update(update_dict)
if reset_fit:
self._is_fitted = False
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def set_period(self, X: pd.DataFrame, y: pd.Series):
"""
self.period = self.determine_periodicity(X, y)
self.parameters.update({"period": self.period})
self.update_parameters({"period": self.period})

def _check_oos_past(self, y):
"""Function to check whether provided target data is out-of-sample and in the past."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,7 @@ def test_decomposer_set_period(decomposer_child_class, period, generate_seasonal
dec.set_period(X, y)

assert 0.95 * period <= dec.period <= 1.05 * period
# TODO: Fix this with https://github.com/alteryx/evalml/issues/3771
# assert dec.parameters["period"] == dec.period
assert dec.parameters["period"] == dec.period


@pytest.mark.parametrize(
Expand Down
17 changes: 17 additions & 0 deletions evalml/tests/component_tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,23 @@ def test_describe_component():
}


def test_update_parameters(X_y_binary):
X, y = X_y_binary
new_val = "New val"

cmp = MockFitComponent()
cmp.fit(X, y)
cmp.update_parameters({"param_a": new_val})
assert cmp.parameters["param_a"] == new_val
assert cmp._is_fitted is False

cmp = MockFitComponent()
cmp.fit(X, y)
cmp.update_parameters({"param_b": new_val}, reset_fit=False)
assert cmp.parameters["param_b"] == new_val
assert cmp._is_fitted is True


def test_missing_attributes(X_y_binary):
class MockComponentName(ComponentBase):
pass
Expand Down

0 comments on commit 70ff947

Please sign in to comment.