Skip to content

Commit

Permalink
Merge pull request #702 from toshihikoyanase/add-save-trial-user-attr…
Browse files Browse the repository at this point in the history
…s-tests

Add unit tests for `save_trial_user_attrs`
  • Loading branch information
c-bata authored Nov 17, 2023
2 parents ba92bd2 + c1c6adc commit cd2159e
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions python_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,47 @@ def test_change_component(self) -> None:
assert study_detail["feedback_component_type"]["output_type"] == "artifact"
assert study_detail["feedback_component_type"]["artifact_key"] == "image"

def test_save_trial_user_attrs(self) -> None:
study = optuna.create_study()
trials: list[optuna.Trial] = []
for _ in range(2):
trial = study.ask()
trials.append(trial)

request_body = {
"user_attrs": {
"number": 0,
},
}

app = create_app(study._storage)
status, _, _ = send_request(
app,
f"/api/trials/{trials[0]._trial_id}/user-attrs",
"POST",
content_type="application/json",
body=json.dumps(request_body),
)
self.assertEqual(status, 204)

assert study.trials[0].user_attrs == request_body["user_attrs"]
assert study.trials[1].user_attrs == {}

def test_save_trial_user_attrs_empty(self) -> None:
study = optuna.create_study()
trial = study.ask()

app = create_app(study._storage)
status, _, _ = send_request(
app,
f"/api/trials/{trial._trial_id}/user-attrs",
"POST",
content_type="application/json",
body=json.dumps({}),
)
self.assertEqual(status, 400)
assert study.trials[0].user_attrs == {}

@pytest.mark.skipif(sys.version_info < (3, 8), reason="BoTorch dropped Python3.7 support")
def test_skip_trial(self) -> None:
storage = optuna.storages.InMemoryStorage()
Expand Down

0 comments on commit cd2159e

Please sign in to comment.