Skip to content

Commit

Permalink
Refactor form widgets output type
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Apr 6, 2023
1 parent 6489092 commit d0cb33e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 29 deletions.
6 changes: 0 additions & 6 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,6 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union_user_attrs,
has_intermediate_values,
) = get_cached_extra_study_property(study_id, trials)
form_widgets_output_type = storage.get_study_system_attrs(study_id).get(
FORM_WIDGETS_OUTPUT_TYPE_KEY
)
if form_widgets_output_type is None:
form_widgets_output_type = "objective"
return serialize_study_detail(
summary,
best_trials,
Expand All @@ -371,7 +366,6 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union,
union_user_attrs,
has_intermediate_values,
form_widgets_output_type,
)

@app.get("/api/studies/<study_id:int>/param_importances")
Expand Down
50 changes: 31 additions & 19 deletions optuna_dashboard/_objective_form_widget.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Union
import warnings
Expand Down Expand Up @@ -49,9 +49,13 @@
"UserAttrRefJSON",
{"type": Literal["user_attr"], "key": str, "user_attr_key": Optional[str]},
)
ObjectiveFormWidgetJSON = Union[
ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON
]
FormWidgetJSON = TypedDict(
"FormWidgetJSON",
{
"output_type": Literal["objective", "user_attr"],
"widgets": list[Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON]]
}
)


@dataclass
Expand Down Expand Up @@ -127,7 +131,6 @@ def to_dict(self) -> UserAttrRefJSON:
ObjectiveSliderWidget = SliderWidget
ObjectiveTextInputWidget = TextInputWidget
FORM_WIDGETS_KEY = "dashboard:form_widgets:v2"
FORM_WIDGETS_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v2"


def register_objective_form_widgets(
Expand All @@ -137,11 +140,11 @@ def register_objective_form_widgets(
raise ValueError("The length of actions must be the same with the number of objectives.")
if any(w.user_attr_key is not None for w in widgets):
warnings.warn("`user_attr_key` specified, but it will not be used.")
widget_dicts = [w.to_dict() for w in widgets]
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, widget_dicts)
study._storage.set_study_system_attr(
study._study_id, FORM_WIDGETS_OUTPUT_TYPE_KEY, "objective"
)
form_widgets: FormWidgetJSON = {
"output_type": "objective",
"widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)


def register_user_attr_form_widgets(
Expand All @@ -151,19 +154,28 @@ def register_user_attr_form_widgets(
raise ValueError("`user_attr_key` is not specified.")
if len(widgets) != len(set(w.user_attr_key for w in widgets)):
raise ValueError("`user_attr_key` must be unique for each widget.")
widget_dicts = [w.to_dict() for w in widgets]
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, widget_dicts)
study._storage.set_study_system_attr(
study._study_id, FORM_WIDGETS_OUTPUT_TYPE_KEY, "user_attr"
)
form_widgets: FormWidgetJSON = {
"output_type": "user_attr",
"widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)


def get_objective_form_widgets_json(
study_system_attr: dict[str, Any]
) -> Optional[list[ObjectiveFormWidgetJSON]]:
def get_form_widgets_json(study_system_attr: dict[str, Any]) -> Optional[FormWidgetJSON]:
if FORM_WIDGETS_KEY in study_system_attr:
return study_system_attr[FORM_WIDGETS_KEY]

# For optuna-dashboard v0.9.0 and v0.9.0b6 users
if "dashboard:objective_form_widgets:v1" in study_system_attr:
return {
"output_type": "objective",
"widgets": study_system_attr["dashboard:objective_form_widgets:v1"]
}

# For optuna-dashboard v0.9.0b5 users
if "dashboard:objective_form_widgets" in study_system_attr:
return json.loads(study_system_attr["dashboard:objective_form_widgets"])
return {
"output_type": "objective",
"widgets": json.loads(study_system_attr["dashboard:objective_form_widgets"])
}
return None
6 changes: 2 additions & 4 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from . import _note as note
from ._named_objectives import get_objective_names
from ._objective_form_widget import get_objective_form_widgets_json
from ._objective_form_widget import get_form_widgets_json
from .artifact._backend import list_trial_artifacts


Expand Down Expand Up @@ -122,7 +122,6 @@ def serialize_study_detail(
union: list[tuple[str, BaseDistribution]],
union_user_attrs: list[tuple[str, bool]],
has_intermediate_values: bool,
form_widgets_output_type: Optional[str],
) -> dict[str, Any]:
serialized: dict[str, Any] = {
"name": summary.study_name,
Expand All @@ -146,10 +145,9 @@ def serialize_study_detail(
objective_names = get_objective_names(system_attrs)
if objective_names:
serialized["objective_names"] = objective_names
objective_form_widgets = get_objective_form_widgets_json(system_attrs)
objective_form_widgets = get_form_widgets_json(system_attrs)
if objective_form_widgets:
serialized["objective_form_widgets"] = objective_form_widgets
serialized["form_widgets_output_type"] = form_widgets_output_type
return serialized


Expand Down

0 comments on commit d0cb33e

Please sign in to comment.