Skip to content

Commit

Permalink
Merge pull request #701 from HideakiImamura/tests/add-unittests-for-t…
Browse files Browse the repository at this point in the history
…ell-trial

Add unittests for `tell_trial`
  • Loading branch information
c-bata authored Nov 17, 2023
2 parents 4741130 + 6d0fef3 commit ba92bd2
Showing 1 changed file with 119 additions and 0 deletions.
119 changes: 119 additions & 0 deletions python_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,125 @@ def test_delete_study_not_found(self) -> None:
)
self.assertEqual(status, 404)

def test_tell_trial_complete(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
trial_id = study.ask()._trial_id

app = create_app(storage)
status, _, _ = send_request(
app,
f"/api/trials/{trial_id}/tell",
"POST",
body=json.dumps(
{
"state": "Complete",
"values": [0, 1, 2],
}
),
content_type="application/json",
)
self.assertEqual(status, 204)
trial = storage.get_trial(trial_id)
assert trial.state == optuna.trial.TrialState.COMPLETE
assert trial.values == [0, 1, 2]

def test_tell_trial_fail(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
trial_id = study.ask()._trial_id

app = create_app(storage)
status, _, _ = send_request(
app,
f"/api/trials/{trial_id}/tell",
"POST",
body=json.dumps(
{
"state": "Fail",
}
),
content_type="application/json",
)
self.assertEqual(status, 204)
trial = storage.get_trial(trial_id)
assert trial.state == optuna.trial.TrialState.FAIL

def test_tell_trial_with_no_state(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
trial_id = study.ask()._trial_id

app = create_app(storage)
status, _, _ = send_request(
app,
f"/api/trials/{trial_id}/tell",
"POST",
body=json.dumps({}),
content_type="application/json",
)
self.assertEqual(status, 400)

def test_tell_trial_with_invalid_state(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
for state in ["Pruned", "Running", "Waiting", "Invalid"]:
trial_id = study.ask()._trial_id
app = create_app(storage)
with self.subTest(state=state):
status, _, _ = send_request(
app,
f"/api/trials/{trial_id}/tell",
"POST",
body=json.dumps(
{
"state": state,
}
),
content_type="application/json",
)
self.assertEqual(status, 400)

def test_tell_trial_with_no_values(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
trial_id = study.ask()._trial_id

app = create_app(storage)
status, _, _ = send_request(
app,
f"/api/trials/{trial_id}/tell",
"POST",
body=json.dumps(
{
"state": "Complete",
}
),
content_type="application/json",
)
self.assertEqual(status, 400)

def test_tell_trial_with_invalid_values(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = optuna.create_study(storage=storage)
for values in [1.0, ["foo"]]:
trial_id = study.ask()._trial_id
app = create_app(storage)
with self.subTest(values=values):
status, _, _ = send_request(
app,
f"/api/trials/{trial_id}/tell",
"POST",
body=json.dumps(
{
"state": "Complete",
"values": values,
}
),
content_type="application/json",
)
self.assertEqual(status, 400)


class BottleRequestHookTestCase(TestCase):
def test_ignore_trailing_slashes(self) -> None:
Expand Down

0 comments on commit ba92bd2

Please sign in to comment.