Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow-up #411 to refactor type definitions #431

Merged
merged 26 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8aadf8a
Implement human-in-the-loop via trial.user_attrs
knshnb Mar 6, 2023
61b1846
Remove unnecessary directions
knshnb Mar 10, 2023
c7221c1
Implement register_user_attr_form_widgets
knshnb Mar 10, 2023
a6a48b8
Fix rendering of user_attrs
knshnb Mar 14, 2023
adc95b6
Update user_attrs state after saveTrialUserAttrsAPI
knshnb Mar 14, 2023
5392b91
Merge remote-tracking branch 'upstream/main' into register-user-attr-…
knshnb Mar 16, 2023
e048b38
Format
knshnb Mar 16, 2023
a706421
Fix
knshnb Mar 16, 2023
1144f82
Use user_attr_key by register_user_attr_form_widgets
knshnb Mar 16, 2023
23c0814
Revert import format
knshnb Mar 16, 2023
6ab58fa
Fix type
knshnb Mar 16, 2023
e85c1fd
Fix type
knshnb Mar 16, 2023
8d16193
Fix import
knshnb Mar 20, 2023
79bf067
Display user_attr_key for register_user_attr_form_widgets
knshnb Mar 20, 2023
261a13d
Apply prettier
knshnb Mar 20, 2023
9c477cb
Add validation of `user_attr_key`
knshnb Mar 20, 2023
017ac67
Remove "Object" prefix from widget classes
knshnb Mar 20, 2023
77ed179
Remove default text form for no widgets
knshnb Mar 20, 2023
091774f
Validate user_attrs have at least one element
knshnb Mar 28, 2023
a5c7b59
Fix typo
knshnb Mar 28, 2023
6489092
Update form widgets key name
knshnb Mar 28, 2023
416c6c7
Include form_widgets_output_type in form_widgets
knshnb Apr 12, 2023
765b0d3
Apply formatter
knshnb Apr 12, 2023
031a1e1
Fix type
knshnb Apr 12, 2023
d6f3846
Follow-up #411 to refactor type definitions
c-bata Apr 13, 2023
ef947d6
Update api.rst
c-bata Apr 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ Human-in-the-loop
:nosignatures:

optuna_dashboard.register_objective_form_widgets
optuna_dashboard.ObjectiveChoiceWidget
optuna_dashboard.ObjectiveSliderWidget
optuna_dashboard.ObjectiveTextInputWidget
optuna_dashboard.ChoiceWidget
optuna_dashboard.SliderWidget
optuna_dashboard.TextInputWidget
optuna_dashboard.ObjectiveUserAttrRef

Artifact
Expand Down
14 changes: 9 additions & 5 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from ._app import run_server # noqa
from ._app import wsgi # noqa
from ._form_widget import ChoiceWidget # noqa
from ._form_widget import ObjectiveChoiceWidget # noqa
from ._form_widget import ObjectiveSliderWidget # noqa
from ._form_widget import ObjectiveTextInputWidget # noqa
from ._form_widget import ObjectiveUserAttrRef # noqa
from ._form_widget import register_objective_form_widgets # noqa
from ._form_widget import register_user_attr_form_widgets # noqa
from ._form_widget import SliderWidget # noqa
from ._form_widget import TextInputWidget # noqa
from ._named_objectives import set_objective_names # noqa
from ._note import save_note # noqa
from ._objective_form_widget import ObjectiveChoiceWidget # noqa
from ._objective_form_widget import ObjectiveSliderWidget # noqa
from ._objective_form_widget import ObjectiveTextInputWidget # noqa
from ._objective_form_widget import ObjectiveUserAttrRef # noqa
from ._objective_form_widget import register_objective_form_widgets # noqa


__version__ = "0.9.0"
18 changes: 18 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,24 @@ def tell_trial(trial_id: int) -> dict[str, Any]:
response.status = 204
return {}

@app.post("/api/trials/<trial_id:int>/user-attrs")
@json_api_view
def save_trial_user_attrs(trial_id: int) -> dict[str, Any]:
user_attrs = request.json.get("user_attrs", {})
if not user_attrs:
response.status = 400 # Bad request
return {"reason": "user_attrs must be specified."}

try:
for key, val in user_attrs.items():
storage.set_trial_user_attr(trial_id, key, val)
except Exception as e:
response.status = 500
return {"reason": f"Internal server error: {e}"}

response.status = 204
return {}

@app.put("/api/studies/<study_id:int>/<trial_id:int>/note")
@json_api_view
def save_trial_note(study_id: int, trial_id: int) -> dict[str, Any]:
Expand Down
183 changes: 183 additions & 0 deletions optuna_dashboard/_form_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from __future__ import annotations

from dataclasses import dataclass
import json
from typing import TYPE_CHECKING
from typing import Union
import warnings

import optuna


if TYPE_CHECKING:
from typing import Any
from typing import Literal
from typing import Optional
from typing import TypedDict

ChoiceWidgetJSON = TypedDict(
"ChoiceWidgetJSON",
{
"type": Literal["choice"],
"description": Optional[str],
"choices": list[str],
"values": list[float],
"user_attr_key": Optional[str],
},
)
SliderWidgetLabel = TypedDict(
"SliderWidgetLabel",
{"value": float, "label": str},
)
SliderWidgetJSON = TypedDict(
"SliderWidgetJSON",
{
"type": Literal["slider"],
"description": Optional[str],
"min": float,
"max": float,
"step": Optional[float],
"labels": Optional[list[SliderWidgetLabel]],
"user_attr_key": Optional[str],
},
)
TextInputWidgetJSON = TypedDict(
"TextInputWidgetJSON",
{"type": Literal["text"], "description": Optional[str], "user_attr_key": Optional[str]},
)
UserAttrRefJSON = TypedDict(
"UserAttrRefJSON",
{"type": Literal["user_attr"], "key": str, "user_attr_key": Optional[str]},
)
FormWidgetJSON = TypedDict(
"FormWidgetJSON",
{
"output_type": Literal["objective", "user_attr"],
"widgets": list[
Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON]
],
},
)


@dataclass
class ChoiceWidget:
choices: list[str]
values: list[float]
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> ChoiceWidgetJSON:
return {
"type": "choice",
"description": self.description,
"choices": self.choices,
"values": self.values,
"user_attr_key": self.user_attr_key,
}


@dataclass
class SliderWidget:
min: float
max: float
step: Optional[float] = None
labels: Optional[list[tuple[float, str]]] = None
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> SliderWidgetJSON:
labels: Optional[list[SliderWidgetLabel]] = None
if self.labels is not None:
labels = [{"value": value, "label": label} for value, label in self.labels]
return {
"type": "slider",
"description": self.description,
"min": self.min,
"max": self.max,
"step": self.step,
"labels": labels,
"user_attr_key": self.user_attr_key,
}


@dataclass
class TextInputWidget:
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> TextInputWidgetJSON:
return {
"type": "text",
"description": self.description,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveUserAttrRef:
key: str
user_attr_key: Optional[str] = None

def to_dict(self) -> UserAttrRefJSON:
return {
"type": "user_attr",
"key": self.key,
"user_attr_key": self.user_attr_key,
}


ObjectiveFormWidget = Union[ChoiceWidget, SliderWidget, TextInputWidget, ObjectiveUserAttrRef]
# For backward compatibility.
ObjectiveChoiceWidget = ChoiceWidget
ObjectiveSliderWidget = SliderWidget
ObjectiveTextInputWidget = TextInputWidget
FORM_WIDGETS_KEY = "dashboard:form_widgets:v2"


def register_objective_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if len(study.directions) != len(widgets):
raise ValueError("The length of actions must be the same with the number of objectives.")
if any(w.user_attr_key is not None for w in widgets):
warnings.warn("`user_attr_key` specified, but it will not be used.")
form_widgets: FormWidgetJSON = {
"output_type": "objective",
"widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)


def register_user_attr_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if any(w.user_attr_key is None for w in widgets):
raise ValueError("`user_attr_key` is not specified.")
if len(widgets) != len(set(w.user_attr_key for w in widgets)):
raise ValueError("`user_attr_key` must be unique for each widget.")
form_widgets: FormWidgetJSON = {
"output_type": "user_attr",
"widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)


def get_form_widgets_json(study_system_attr: dict[str, Any]) -> Optional[FormWidgetJSON]:
if FORM_WIDGETS_KEY in study_system_attr:
return study_system_attr[FORM_WIDGETS_KEY]

# For optuna-dashboard v0.9.0 and v0.9.0b6 users
if "dashboard:objective_form_widgets:v1" in study_system_attr:
return {
"output_type": "objective",
"widgets": study_system_attr["dashboard:objective_form_widgets:v1"],
}

# For optuna-dashboard v0.9.0b5 users
if "dashboard:objective_form_widgets" in study_system_attr:
return {
"output_type": "objective",
"widgets": json.loads(study_system_attr["dashboard:objective_form_widgets"]),
}
return None
133 changes: 0 additions & 133 deletions optuna_dashboard/_objective_form_widget.py

This file was deleted.

Loading