diff --git a/physbo/search/discrete/policy.py b/physbo/search/discrete/policy.py index 4a62bdd9..a2b8b1f3 100644 --- a/physbo/search/discrete/policy.py +++ b/physbo/search/discrete/policy.py @@ -129,6 +129,9 @@ def write( time_run_simulator=time_run_simulator, ) self.training.add(X=X, t=t, Z=Z) + local_index = np.searchsorted(self.actions, action) + local_index = local_index[np.take(self.actions, local_index, mode='clip') == action] + self.actions = self._delete_actions(local_index) if self.new_data is None: self.new_data = variable(X=X, t=t, Z=Z) else: diff --git a/physbo/search/discrete_multi/policy.py b/physbo/search/discrete_multi/policy.py index 35402939..9b245841 100644 --- a/physbo/search/discrete_multi/policy.py +++ b/physbo/search/discrete_multi/policy.py @@ -99,6 +99,9 @@ def write( else: self.new_data_list[i].add(X=X, t=t[:, i], Z=Z) self.training_list[i].add(X=X, t=t[:, i], Z=Z) + local_index = np.searchsorted(self.actions, action) + local_index = local_index[np.take(self.actions, local_index, mode='clip') == action] + self.actions = self._delete_actions(local_index) def _model(self, i): training = self.training_list[i] diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py index ec778bdc..16d8d7c4 100644 --- a/tests/unit/test_policy.py +++ b/tests/unit/test_policy.py @@ -42,6 +42,15 @@ def policy(): return physbo.search.discrete.policy(test_X=X) +def test_write(policy, X): + simulator = lambda x: 1.0 + ACTIONS = np.array([0, 1], np.int32) + + policy.write(ACTIONS, np.apply_along_axis(simulator, 1, X[ACTIONS])) + numpy.testing.assert_array_equal(ACTIONS, policy.history.chosen_actions[:len(ACTIONS)]) + assert len(X) - len(ACTIONS) == len(policy.actions) + + def test_randomsearch(policy, mocker): simulator = mocker.MagicMock(return_value=1.0) write_spy = mocker.spy(physbo.search.discrete.policy, "write")