diff --git a/python_tests/test_api.py b/python_tests/test_api.py index bde7838a2..ce78da47a 100644 --- a/python_tests/test_api.py +++ b/python_tests/test_api.py @@ -399,6 +399,125 @@ def test_delete_study_not_found(self) -> None: ) self.assertEqual(status, 404) + def test_tell_trial_complete(self) -> None: + storage = optuna.storages.InMemoryStorage() + study = optuna.create_study(storage=storage) + trial_id = study.ask()._trial_id + + app = create_app(storage) + status, _, _ = send_request( + app, + f"/api/trials/{trial_id}/tell", + "POST", + body=json.dumps( + { + "state": "Complete", + "values": [0, 1, 2], + } + ), + content_type="application/json", + ) + self.assertEqual(status, 204) + trial = storage.get_trial(trial_id) + assert trial.state == optuna.trial.TrialState.COMPLETE + assert trial.values == [0, 1, 2] + + def test_tell_trial_fail(self) -> None: + storage = optuna.storages.InMemoryStorage() + study = optuna.create_study(storage=storage) + trial_id = study.ask()._trial_id + + app = create_app(storage) + status, _, _ = send_request( + app, + f"/api/trials/{trial_id}/tell", + "POST", + body=json.dumps( + { + "state": "Fail", + } + ), + content_type="application/json", + ) + self.assertEqual(status, 204) + trial = storage.get_trial(trial_id) + assert trial.state == optuna.trial.TrialState.FAIL + + def test_tell_trial_with_no_state(self) -> None: + storage = optuna.storages.InMemoryStorage() + study = optuna.create_study(storage=storage) + trial_id = study.ask()._trial_id + + app = create_app(storage) + status, _, _ = send_request( + app, + f"/api/trials/{trial_id}/tell", + "POST", + body=json.dumps({}), + content_type="application/json", + ) + self.assertEqual(status, 400) + + def test_tell_trial_with_invalid_state(self) -> None: + storage = optuna.storages.InMemoryStorage() + study = optuna.create_study(storage=storage) + for state in ["Pruned", "Running", "Waiting", "Invalid"]: + trial_id = study.ask()._trial_id + app = create_app(storage) + with self.subTest(state=state): + status, _, _ = send_request( + app, + f"/api/trials/{trial_id}/tell", + "POST", + body=json.dumps( + { + "state": state, + } + ), + content_type="application/json", + ) + self.assertEqual(status, 400) + + def test_tell_trial_with_no_values(self) -> None: + storage = optuna.storages.InMemoryStorage() + study = optuna.create_study(storage=storage) + trial_id = study.ask()._trial_id + + app = create_app(storage) + status, _, _ = send_request( + app, + f"/api/trials/{trial_id}/tell", + "POST", + body=json.dumps( + { + "state": "Complete", + } + ), + content_type="application/json", + ) + self.assertEqual(status, 400) + + def test_tell_trial_with_invalid_values(self) -> None: + storage = optuna.storages.InMemoryStorage() + study = optuna.create_study(storage=storage) + for values in [1.0, ["foo"]]: + trial_id = study.ask()._trial_id + app = create_app(storage) + with self.subTest(values=values): + status, _, _ = send_request( + app, + f"/api/trials/{trial_id}/tell", + "POST", + body=json.dumps( + { + "state": "Complete", + "values": values, + } + ), + content_type="application/json", + ) + self.assertEqual(status, 400) + class BottleRequestHookTestCase(TestCase): def test_ignore_trailing_slashes(self) -> None: