From 8aadf8aab3d70cade4ce7dd946dcb1ef510663eb Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Mon, 6 Mar 2023 21:34:01 +0900 Subject: [PATCH 01/23] Implement human-in-the-loop via trial.user_attrs --- optuna_dashboard/_app.py | 17 +++++++++++ optuna_dashboard/ts/action.ts | 28 +++++++++++++++++++ optuna_dashboard/ts/apiClient.ts | 13 +++++++++ .../ts/components/ObjectiveForm.tsx | 7 ++--- 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index d91710f66..c521775a2 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -447,6 +447,23 @@ def tell_trial(trial_id: int) -> dict[str, Any]: response.status = 204 return {} + @app.post("/api/trials//user-attrs") + @json_api_view + def save_trial_user_attrs(trial_id: int) -> dict[str, Any]: + if "user_attrs" not in request.json: + response.status = 400 # Bad request + return {"reason": "user_attrs must be specified."} + + try: # TODO(knshnb): Proper error handling. + for key, val in request.json.get("user_attrs").items(): + storage.set_trial_user_attr(trial_id, key, val) + except Exception as e: + response.status = 500 + return {"reason": f"Internal server error: {e}"} + + response.status = 204 + return {} + @app.put("/api/studies///note") @json_api_view def save_trial_note(study_id: int, trial_id: int) -> dict[str, Any]: diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index e8f2aa35c..1a4ba8705 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -13,6 +13,7 @@ import { uploadArtifactAPI, getMetaInfoAPI, deleteArtifactAPI, + saveTrialUserAttrsAPI, } from "./apiClient" import { graphVisibilityState, @@ -482,6 +483,32 @@ export const actionCreator = () => { }) } + const saveTrialUserAttrs = ( + studyId: number, + trialId: number, + user_attrs: {[key: string]: number}, + ): void => { + console.log("user_attrs", user_attrs) + // TODO(knshnb): Update rendering of `user_attrs`. + const message = `id=${trialId}, user_attrs=${user_attrs}` + saveTrialUserAttrsAPI(trialId, user_attrs) + .then(() => { + // TODO(knshnb): Update states. + enqueueSnackbar(`Successfully updated trial (${message})`, { + variant: "success", + }) + }) + .catch((err) => { + const reason = err.response?.data.reason + enqueueSnackbar( + `Failed to update trial (${message}). Reason: ${reason}`, + { + variant: "error", + } + ) + console.log(err) + }) + } return { updateAPIMeta, updateStudyDetail, @@ -499,6 +526,7 @@ export const actionCreator = () => { uploadArtifact, deleteArtifact, tellTrial, + saveTrialUserAttrs, } } diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index e263117c7..2cd32c530 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -281,6 +281,19 @@ export const tellTrialAPI = ( }) } +export const saveTrialUserAttrsAPI = ( + trialId: number, + user_attrs: { [key: string]: number } +): Promise => { + const req = { user_attrs: user_attrs } + + return axiosInstance + .post(`/api/trials/${trialId}/user-attrs`, req) + .then((res) => { + return + }) +} + interface ParamImportancesResponse { param_importances: ParamImportance[][] } diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index ae8b2ae88..e2bca3de0 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -65,11 +65,8 @@ export const ObjectiveForm: FC<{ const handleSubmit = (e: React.MouseEvent): void => { e.preventDefault() - const filtered = values.filter((v): v is number => v !== null) - if (filtered.length !== directions.length) { - return - } - action.tellTrial(trial.study_id, trial.trial_id, "Complete", filtered) + const user_attrs = Object.fromEntries(widgets.map((widget, i) => [widget.description, values[i]])) + action.saveTrialUserAttrs(trial.study_id, trial.trial_id, user_attrs) } const getObjectiveName = (i: number): string => { From 61b18460db36ca36962be72e5d58d0efe3c3e8c3 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Fri, 10 Mar 2023 16:11:19 +0900 Subject: [PATCH 02/23] Remove unnecessary directions --- optuna_dashboard/ts/components/ObjectiveForm.tsx | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index e2bca3de0..6433bbe7b 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -25,8 +25,7 @@ export const ObjectiveForm: FC<{ const theme = useTheme() const action = actionCreator() const [values, setValues] = useState<(number | null)[]>( - directions.map((d, i) => { - const widget = widgets.at(i) + widgets.map(widget => { if (widget === undefined) { return null } else if (widget.type === "text") { @@ -99,8 +98,7 @@ export const ObjectiveForm: FC<{ p: theme.spacing(1), }} > - {directions.map((d, i) => { - const widget = widgets.at(i) + {widgets.map((widget, i) => { const value = values.at(i) const key = `objective-${i}` if (widget === undefined) { @@ -309,8 +307,7 @@ export const ReadonlyObjectiveForm: FC<{ p: theme.spacing(1), }} > - {directions.map((d, i) => { - const widget = widgets.at(i) + {widgets.map((widget, i) => { const key = `objective-${i}` if (widget === undefined) { return ( From c7221c1ea4d4d85500f0a3c4a3d2381aadaea2db Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Fri, 10 Mar 2023 16:18:08 +0900 Subject: [PATCH 03/23] Implement register_user_attr_form_widgets --- optuna_dashboard/__init__.py | 1 + optuna_dashboard/_app.py | 3 +++ optuna_dashboard/_objective_form_widget.py | 10 ++++++++++ optuna_dashboard/_serializer.py | 3 +++ optuna_dashboard/ts/apiClient.ts | 2 ++ optuna_dashboard/ts/components/ObjectiveForm.tsx | 15 ++++++++++++--- optuna_dashboard/ts/components/TrialList.tsx | 6 ++++++ optuna_dashboard/ts/types/index.d.ts | 1 + 8 files changed, 38 insertions(+), 3 deletions(-) diff --git a/optuna_dashboard/__init__.py b/optuna_dashboard/__init__.py index 7b273ed17..cf75f553c 100644 --- a/optuna_dashboard/__init__.py +++ b/optuna_dashboard/__init__.py @@ -7,6 +7,7 @@ from ._objective_form_widget import ObjectiveTextInputWidget # noqa from ._objective_form_widget import ObjectiveUserAttrRef # noqa from ._objective_form_widget import register_objective_form_widgets # noqa +from ._objective_form_widget import register_user_attr_form_widgets # noqa __version__ = "0.9.0b6" diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index c521775a2..723896ccd 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -34,6 +34,7 @@ from ._bottle_util import json_api_view from ._cached_extra_study_property import get_cached_extra_study_property from ._importance import get_param_importance_from_trials_cache +from ._objective_form_widget import SYSTEM_ATTR_OUTPUT_TYPE_KEY from ._pareto_front import get_pareto_front_trials from ._serializer import serialize_study_detail from ._serializer import serialize_study_summary @@ -357,6 +358,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union_user_attrs, has_intermediate_values, ) = get_cached_extra_study_property(study_id, trials) + form_widgets_output_type = storage.get_study_system_attrs(study_id).get(SYSTEM_ATTR_OUTPUT_TYPE_KEY) return serialize_study_detail( summary, best_trials, @@ -365,6 +367,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union, union_user_attrs, has_intermediate_values, + form_widgets_output_type, ) @app.get("/api/studies//param_importances") diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 3f19631ee..ea3d737ca 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -111,6 +111,7 @@ def to_dict(self) -> UserAttrRefJSON: ObjectiveChoiceWidget, ObjectiveSliderWidget, ObjectiveTextInputWidget, ObjectiveUserAttrRef ] SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets:v1" +SYSTEM_ATTR_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v1" def register_objective_form_widgets( @@ -120,6 +121,15 @@ def register_objective_form_widgets( raise ValueError("The length of actions must be the same with the number of objectives.") widget_dicts = [w.to_dict() for w in widgets] study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts) + study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "objective") + + +def register_user_attr_form_widgets( + study: optuna.Study, widgets: list[ObjectiveFormWidget] +) -> None: + widget_dicts = [w.to_dict() for w in widgets] + study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts) + study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "user_attr") def get_objective_form_widgets_json( diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index 021705703..b5ef1024c 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -2,6 +2,7 @@ import json from typing import Any +from typing import Optional from typing import TYPE_CHECKING from typing import Union @@ -121,6 +122,7 @@ def serialize_study_detail( union: list[tuple[str, BaseDistribution]], union_user_attrs: list[tuple[str, bool]], has_intermediate_values: bool, + form_widgets_output_type: Optional[str], ) -> dict[str, Any]: serialized: dict[str, Any] = { "name": summary.study_name, @@ -147,6 +149,7 @@ def serialize_study_detail( objective_form_widgets = get_objective_form_widgets_json(system_attrs) if objective_form_widgets: serialized["objective_form_widgets"] = objective_form_widgets + serialized["form_widgets_output_type"] = form_widgets_output_type return serialized diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index 2cd32c530..849bf2871 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -68,6 +68,7 @@ interface StudyDetailResponse { note: Note objective_names?: string[] objective_form_widgets?: ObjectiveFormWidget[] + form_widgets_output_type?: string } export const getStudyDetailAPI = ( @@ -101,6 +102,7 @@ export const getStudyDetailAPI = ( note: res.data.note, objective_names: res.data.objective_names, objective_form_widgets: res.data.objective_form_widgets, + form_widgets_output_type: res.data.form_widgets_output_type, } }) } diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index 6433bbe7b..704a79165 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -21,7 +21,8 @@ export const ObjectiveForm: FC<{ directions: StudyDirection[] names: string[] widgets: ObjectiveFormWidget[] -}> = ({ trial, directions, names, widgets }) => { + outputType: string +}> = ({ trial, directions, names, widgets, outputType }) => { const theme = useTheme() const action = actionCreator() const [values, setValues] = useState<(number | null)[]>( @@ -64,8 +65,16 @@ export const ObjectiveForm: FC<{ const handleSubmit = (e: React.MouseEvent): void => { e.preventDefault() - const user_attrs = Object.fromEntries(widgets.map((widget, i) => [widget.description, values[i]])) - action.saveTrialUserAttrs(trial.study_id, trial.trial_id, user_attrs) + if (outputType == "objective") { + const filtered = values.filter((v): v is number => v !== null) + if (filtered.length !== directions.length) { + return + } + action.tellTrial(trial.study_id, trial.trial_id, "Complete", filtered) + } else if (outputType == "user_attr") { + const user_attrs = Object.fromEntries(widgets.map((widget, i) => [widget.description, values[i]])) + action.saveTrialUserAttrs(trial.study_id, trial.trial_id, user_attrs) + } } const getObjectiveName = (i: number): string => { diff --git a/optuna_dashboard/ts/components/TrialList.tsx b/optuna_dashboard/ts/components/TrialList.tsx index 7ef89d737..85ba3d687 100644 --- a/optuna_dashboard/ts/components/TrialList.tsx +++ b/optuna_dashboard/ts/components/TrialList.tsx @@ -144,12 +144,14 @@ const TrialListDetail: FC<{ directions: StudyDirection[] objectiveNames: string[] objectiveFormWidgets: ObjectiveFormWidget[] + formWigetsOutputType: string }> = ({ trial, isBestTrial, directions, objectiveNames, objectiveFormWidgets, + formWigetsOutputType }) => { const theme = useTheme() const artifactEnabled = useRecoilValue(artifactIsAvailable) @@ -297,6 +299,7 @@ const TrialListDetail: FC<{ directions={directions} names={objectiveNames} widgets={objectiveFormWidgets} + outputType={formWigetsOutputType} /> )} {trial.state === "Complete" && directions.length > 0 && ( @@ -816,6 +819,9 @@ export const TrialList: FC<{ studyDetail: StudyDetail | null }> = ({ objectiveFormWidgets={ studyDetail?.objective_form_widgets || [] } + formWigetsOutputType={ + studyDetail?.form_widgets_output_type || "" + } /> ))} diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index 1d30bc574..87b33a299 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -176,6 +176,7 @@ type StudyDetail = { note: Note objective_names?: string[] objective_form_widgets?: ObjectiveFormWidget[] + form_widgets_output_type?: string } type StudyDetails = { From a6a48b8e99ba38cf3bb43c882b14733b9b50b9da Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Tue, 14 Mar 2023 16:36:22 +0900 Subject: [PATCH 04/23] Fix rendering of user_attrs --- optuna_dashboard/ts/action.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index 1a4ba8705..43967b01e 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -489,8 +489,7 @@ export const actionCreator = () => { user_attrs: {[key: string]: number}, ): void => { console.log("user_attrs", user_attrs) - // TODO(knshnb): Update rendering of `user_attrs`. - const message = `id=${trialId}, user_attrs=${user_attrs}` + const message = `id=${trialId}, user_attrs=${JSON.stringify(user_attrs)}` saveTrialUserAttrsAPI(trialId, user_attrs) .then(() => { // TODO(knshnb): Update states. From adc95b6251773cabceb0a7866902334d0310e5ad Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Tue, 14 Mar 2023 17:35:18 +0900 Subject: [PATCH 05/23] Update user_attrs state after saveTrialUserAttrsAPI --- optuna_dashboard/ts/action.ts | 59 +++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index 43967b01e..76889e77f 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -1,30 +1,15 @@ -import { useRecoilState, useSetRecoilState } from "recoil" import { useSnackbar } from "notistack" +import { useRecoilState, useSetRecoilState } from "recoil" import { - getStudyDetailAPI, - getStudySummariesAPI, - getParamImportances, - createNewStudyAPI, - deleteStudyAPI, - saveStudyNoteAPI, - saveTrialNoteAPI, - tellTrialAPI, - renameStudyAPI, - uploadArtifactAPI, - getMetaInfoAPI, - deleteArtifactAPI, - saveTrialUserAttrsAPI, + createNewStudyAPI, deleteArtifactAPI, deleteStudyAPI, getMetaInfoAPI, getParamImportances, getStudyDetailAPI, + getStudySummariesAPI, renameStudyAPI, saveStudyNoteAPI, + saveTrialNoteAPI, saveTrialUserAttrsAPI, tellTrialAPI, uploadArtifactAPI } from "./apiClient" +import { getDominatedTrials } from "./dominatedTrials" import { - graphVisibilityState, - studyDetailsState, - studySummariesState, - paramImportanceState, - isFileUploading, - artifactIsAvailable, - reloadIntervalState, + artifactIsAvailable, graphVisibilityState, isFileUploading, paramImportanceState, reloadIntervalState, studyDetailsState, + studySummariesState } from "./state" -import { getDominatedTrials } from "./dominatedTrials" const localStorageGraphVisibility = "graphVisibility" const localStorageReloadInterval = "reloadInterval" @@ -152,6 +137,23 @@ export const actionCreator = () => { setStudyDetailState(studyId, newStudy) } + const setTrialUserAttrs = ( + studyId: number, + index: number, + user_attrs: { [key: string]: number }, + ) => { + const newTrial: Trial = Object.assign( + {}, + studyDetails[studyId].trials[index] + ) + newTrial.user_attrs = Object.keys(user_attrs).map(key => ({ key: key, value: user_attrs[key].toString() })) + const newTrials: Trial[] = [...studyDetails[studyId].trials] + newTrials[index] = newTrial + const newStudy: StudyDetail = Object.assign({}, studyDetails[studyId]) + newStudy.trials = newTrials + setStudyDetailState(studyId, newStudy) + } + const setStudyParamImportanceState = ( studyId: number, importance: ParamImportance[][] @@ -486,13 +488,22 @@ export const actionCreator = () => { const saveTrialUserAttrs = ( studyId: number, trialId: number, - user_attrs: {[key: string]: number}, + user_attrs: { [key: string]: number }, ): void => { console.log("user_attrs", user_attrs) const message = `id=${trialId}, user_attrs=${JSON.stringify(user_attrs)}` saveTrialUserAttrsAPI(trialId, user_attrs) .then(() => { - // TODO(knshnb): Update states. + const index = studyDetails[studyId].trials.findIndex( + (t) => t.trial_id === trialId + ) + if (index === -1) { + enqueueSnackbar(`Unexpected error happens. Please reload the page.`, { + variant: "error", + }) + return + } + setTrialUserAttrs(studyId, index, user_attrs) enqueueSnackbar(`Successfully updated trial (${message})`, { variant: "success", }) From e048b38b5d7e0221e1a17fb7a74d3cf837a77bf6 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 16 Mar 2023 17:29:46 +0900 Subject: [PATCH 06/23] Format --- optuna_dashboard/_app.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index 723896ccd..54d07fa43 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -358,7 +358,11 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union_user_attrs, has_intermediate_values, ) = get_cached_extra_study_property(study_id, trials) - form_widgets_output_type = storage.get_study_system_attrs(study_id).get(SYSTEM_ATTR_OUTPUT_TYPE_KEY) + form_widgets_output_type = storage.get_study_system_attrs(study_id).get( + SYSTEM_ATTR_OUTPUT_TYPE_KEY + ) + if form_widgets_output_type is None: + form_widgets_output_type = "objective" return serialize_study_detail( summary, best_trials, From a70642105c5fb6e76bd6ff0012a1b2d9c02aee61 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 16 Mar 2023 17:38:54 +0900 Subject: [PATCH 07/23] Fix --- optuna_dashboard/ts/components/ObjectiveForm.tsx | 1 + optuna_dashboard/ts/components/TrialList.tsx | 1 + 2 files changed, 2 insertions(+) diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index 40b869a9f..92ad178ee 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -285,6 +285,7 @@ export const ReadonlyObjectiveForm: FC<{ directions: StudyDirection[] names: string[] widgets: ObjectiveFormWidget[] + outputType: string }> = ({ trial, directions, names, widgets }) => { const theme = useTheme() const getObjectiveName = (i: number): string => { diff --git a/optuna_dashboard/ts/components/TrialList.tsx b/optuna_dashboard/ts/components/TrialList.tsx index 0f1a51da1..06d3227e2 100644 --- a/optuna_dashboard/ts/components/TrialList.tsx +++ b/optuna_dashboard/ts/components/TrialList.tsx @@ -308,6 +308,7 @@ const TrialListDetail: FC<{ directions={directions} names={objectiveNames} widgets={objectiveFormWidgets} + outputType={formWigetsOutputType} /> )} {artifactEnabled && } From 1144f820765d34c3b7a3674176072061957bc7f5 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 16 Mar 2023 17:41:24 +0900 Subject: [PATCH 08/23] Use user_attr_key by register_user_attr_form_widgets --- optuna_dashboard/_objective_form_widget.py | 16 ++++++++++++++-- optuna_dashboard/ts/components/ObjectiveForm.tsx | 2 +- optuna_dashboard/ts/types/index.d.ts | 4 ++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index ea3d737ca..8b3375549 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -21,6 +21,7 @@ "description": Optional[str], "choices": list[str], "values": list[float], + "user_attr_key": Optional[str], }, ) SliderWidgetLabel = TypedDict( @@ -36,13 +37,16 @@ "max": float, "step": Optional[float], "labels": Optional[list[SliderWidgetLabel]], + "user_attr_key": Optional[str], }, ) TextInputWidgetJSON = TypedDict( "TextInputWidgetJSON", - {"type": Literal["text"], "description": Optional[str]}, + {"type": Literal["text"], "description": Optional[str], "user_attr_key": Optional[str]}, + ) + UserAttrRefJSON = TypedDict( + "UserAttrRefJSON", {"type": Literal["user_attr"], "user_attr_key": str} ) - UserAttrRefJSON = TypedDict("UserAttrRefJSON", {"type": Literal["user_attr"], "key": str}) ObjectiveFormWidgetJSON = Union[ ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON ] @@ -53,6 +57,7 @@ class ObjectiveChoiceWidget: choices: list[str] values: list[float] description: Optional[str] = None + user_attr_key: Optional[str] = None def to_dict(self) -> ChoiceWidgetJSON: return { @@ -60,6 +65,7 @@ def to_dict(self) -> ChoiceWidgetJSON: "description": self.description, "choices": self.choices, "values": self.values, + "user_attr_key": self.user_attr_key, } @@ -70,6 +76,7 @@ class ObjectiveSliderWidget: step: Optional[float] = None labels: Optional[list[tuple[float, str]]] = None description: Optional[str] = None + user_attr_key: Optional[str] = None def to_dict(self) -> SliderWidgetJSON: labels: Optional[list[SliderWidgetLabel]] = None @@ -82,28 +89,33 @@ def to_dict(self) -> SliderWidgetJSON: "max": self.max, "step": self.step, "labels": labels, + "user_attr_key": self.user_attr_key, } @dataclass class ObjectiveTextInputWidget: description: Optional[str] = None + user_attr_key: Optional[str] = None def to_dict(self) -> TextInputWidgetJSON: return { "type": "text", "description": self.description, + "user_attr_key": self.user_attr_key, } @dataclass class ObjectiveUserAttrRef: key: str + user_attr_key: Optional[str] = None def to_dict(self) -> UserAttrRefJSON: return { "type": "user_attr", "key": self.key, + "user_attr_key": self.user_attr_key, } diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index 92ad178ee..6214e7405 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -72,7 +72,7 @@ export const ObjectiveForm: FC<{ } action.makeTrialComplete(trial.study_id, trial.trial_id, filtered) } else if (outputType == "user_attr") { - const user_attrs = Object.fromEntries(widgets.map((widget, i) => [widget.description, values[i]])) + const user_attrs = Object.fromEntries(widgets.map((widget, i) => [widget.user_attr_key, values[i]])) action.saveTrialUserAttrs(trial.study_id, trial.trial_id, user_attrs) } } diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index 87b33a299..4410691af 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -128,6 +128,7 @@ type StudySummary = { type ObjectiveChoiceWidget = { type: "choice" description: string + user_attr_key?: string choices: string[] values: number[] } @@ -135,6 +136,7 @@ type ObjectiveChoiceWidget = { type ObjectiveSliderWidget = { type: "slider" description: string + user_attr_key?: string min: number max: number step: number @@ -149,11 +151,13 @@ type ObjectiveSliderWidget = { type ObjectiveTextInputWidget = { type: "text" description: string + user_attr_key?: string } type ObjectiveUserAttrRef = { type: "user_attr" key: string + user_attr_key?: string } type ObjectiveFormWidget = From 23c0814d71b179f08bbf7f07205afe2dddc6b461 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 16 Mar 2023 17:50:28 +0900 Subject: [PATCH 09/23] Revert import format --- optuna_dashboard/ts/action.ts | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index 3090948cf..e6b99f2a1 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -1,15 +1,29 @@ -import { useSnackbar } from "notistack" import { useRecoilState, useSetRecoilState } from "recoil" +import { useSnackbar } from "notistack" import { - createNewStudyAPI, deleteArtifactAPI, deleteStudyAPI, getMetaInfoAPI, getParamImportances, getStudyDetailAPI, - getStudySummariesAPI, renameStudyAPI, saveStudyNoteAPI, - saveTrialNoteAPI, saveTrialUserAttrsAPI, tellTrialAPI, uploadArtifactAPI + getStudyDetailAPI, + getStudySummariesAPI, + getParamImportances, + createNewStudyAPI, + deleteStudyAPI, + saveStudyNoteAPI, + saveTrialNoteAPI, + tellTrialAPI, + renameStudyAPI, + uploadArtifactAPI, + getMetaInfoAPI, + deleteArtifactAPI, } from "./apiClient" -import { getDominatedTrials } from "./dominatedTrials" import { - artifactIsAvailable, graphVisibilityState, isFileUploading, paramImportanceState, reloadIntervalState, studyDetailsState, - studySummariesState + graphVisibilityState, + studyDetailsState, + studySummariesState, + paramImportanceState, + isFileUploading, + artifactIsAvailable, + reloadIntervalState, } from "./state" +import { getDominatedTrials } from "./dominatedTrials" const localStorageGraphVisibility = "graphVisibility" const localStorageReloadInterval = "reloadInterval" From 6ab58fafa232e6813f6b9048570ca2973aa55f83 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 16 Mar 2023 18:47:58 +0900 Subject: [PATCH 10/23] Fix type --- optuna_dashboard/_objective_form_widget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 8b3375549..126eef764 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -45,7 +45,7 @@ {"type": Literal["text"], "description": Optional[str], "user_attr_key": Optional[str]}, ) UserAttrRefJSON = TypedDict( - "UserAttrRefJSON", {"type": Literal["user_attr"], "user_attr_key": str} + "UserAttrRefJSON", {"type": Literal["user_attr"], "user_attr_key": Optional[str]} ) ObjectiveFormWidgetJSON = Union[ ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON From e85c1fdea998652131abdb5f079d74f57ac1575d Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Thu, 16 Mar 2023 18:54:10 +0900 Subject: [PATCH 11/23] Fix type --- optuna_dashboard/_objective_form_widget.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 126eef764..7c6940d96 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -45,7 +45,8 @@ {"type": Literal["text"], "description": Optional[str], "user_attr_key": Optional[str]}, ) UserAttrRefJSON = TypedDict( - "UserAttrRefJSON", {"type": Literal["user_attr"], "user_attr_key": Optional[str]} + "UserAttrRefJSON", + {"type": Literal["user_attr"], "key": str, "user_attr_key": Optional[str]}, ) ObjectiveFormWidgetJSON = Union[ ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON From 8d161935ec7c46256dce0df89f1eb58fb7c3eb44 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Mon, 20 Mar 2023 14:49:48 +0900 Subject: [PATCH 12/23] Fix import --- optuna_dashboard/ts/action.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index e6b99f2a1..b9fb6a61e 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -9,6 +9,7 @@ import { saveStudyNoteAPI, saveTrialNoteAPI, tellTrialAPI, + saveTrialUserAttrsAPI, renameStudyAPI, uploadArtifactAPI, getMetaInfoAPI, From 79bf067747206876a554ac738a0e3251011cbd25 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Mon, 20 Mar 2023 15:24:48 +0900 Subject: [PATCH 13/23] Display user_attr_key for register_user_attr_form_widgets --- .../ts/components/ObjectiveForm.tsx | 70 +++++++++++-------- 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index 6214e7405..aaec1226d 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -77,16 +77,21 @@ export const ObjectiveForm: FC<{ } } - const getObjectiveName = (i: number): string => { - const n = names.at(i) - if (n !== undefined) { - return n - } - if (directions.length == 1) { - return `Objective` - } else { - return `Objective ${i}` + const getMetricName = (i: number): string => { + if (outputType == "objective") { + const n = names.at(i) + if (n !== undefined) { + return n + } + if (directions.length == 1) { + return `Objective` + } else { + return `Objective ${i}` + } + } else if (outputType == "user_attr") { + return widgets[i].user_attr_key as string } + return "Unkown metric name" } return ( @@ -113,7 +118,7 @@ export const ObjectiveForm: FC<{ if (widget === undefined) { return ( - {getObjectiveName(i)} + {getMetricName(i)} { const n = Number(s) @@ -133,7 +138,7 @@ export const ObjectiveForm: FC<{ value === null || value === undefined ? `Please input the float number.` : "", - label: getObjectiveName(i), + label: getMetricName(i), type: "text", }} /> @@ -143,7 +148,7 @@ export const ObjectiveForm: FC<{ return ( - {getObjectiveName(i)} - {widget.description} + {getMetricName(i)} - {widget.description} { @@ -176,7 +181,7 @@ export const ObjectiveForm: FC<{ return ( - {getObjectiveName(i)} - {widget.description} + {getMetricName(i)} - {widget.description} {widget.choices.map((c, j) => ( @@ -206,7 +211,7 @@ export const ObjectiveForm: FC<{ return ( - {getObjectiveName(i)} - {widget.description} + {getMetricName(i)} - {widget.description} - {getObjectiveName(i)} + {getMetricName(i)} = ({ trial, directions, names, widgets }) => { +}> = ({ trial, directions, names, widgets, outputType }) => { const theme = useTheme() - const getObjectiveName = (i: number): string => { - const n = names.at(i) - if (n !== undefined) { - return n - } - if (directions.length == 1) { - return `Objective` - } else { - return `Objective ${i}` + const getMetricName = (i: number): string => { + if (outputType == "objective") { + const n = names.at(i) + if (n !== undefined) { + return n + } + if (directions.length == 1) { + return `Objective` + } else { + return `Objective ${i}` + } + } else if (outputType == "user_attr") { + return widgets[i].user_attr_key as string } + return "Unkown metric name" } return ( <> @@ -322,7 +332,7 @@ export const ReadonlyObjectiveForm: FC<{ if (widget === undefined) { return ( - {getObjectiveName(i)} + {getMetricName(i)} - {getObjectiveName(i)} - {widget.description} + {getMetricName(i)} - {widget.description} - {getObjectiveName(i)} - {widget.description} + {getMetricName(i)} - {widget.description} {widget.choices.map((c, j) => ( @@ -370,7 +380,7 @@ export const ReadonlyObjectiveForm: FC<{ return ( - {getObjectiveName(i)} - {widget.description} + {getMetricName(i)} - {widget.description} - {getObjectiveName(i)} + {getMetricName(i)} Date: Mon, 20 Mar 2023 15:27:31 +0900 Subject: [PATCH 14/23] Apply prettier --- optuna_dashboard/ts/action.ts | 9 ++++++--- optuna_dashboard/ts/components/ObjectiveForm.tsx | 6 ++++-- optuna_dashboard/ts/components/TrialList.tsx | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index b9fb6a61e..657adbd5e 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -155,13 +155,16 @@ export const actionCreator = () => { const setTrialUserAttrs = ( studyId: number, index: number, - user_attrs: { [key: string]: number }, + user_attrs: { [key: string]: number } ) => { const newTrial: Trial = Object.assign( {}, studyDetails[studyId].trials[index] ) - newTrial.user_attrs = Object.keys(user_attrs).map(key => ({ key: key, value: user_attrs[key].toString() })) + newTrial.user_attrs = Object.keys(user_attrs).map((key) => ({ + key: key, + value: user_attrs[key].toString(), + })) const newTrials: Trial[] = [...studyDetails[studyId].trials] newTrials[index] = newTrial const newStudy: StudyDetail = Object.assign({}, studyDetails[studyId]) @@ -529,7 +532,7 @@ export const actionCreator = () => { const saveTrialUserAttrs = ( studyId: number, trialId: number, - user_attrs: { [key: string]: number }, + user_attrs: { [key: string]: number } ): void => { console.log("user_attrs", user_attrs) const message = `id=${trialId}, user_attrs=${JSON.stringify(user_attrs)}` diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index aaec1226d..379550ab8 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -26,7 +26,7 @@ export const ObjectiveForm: FC<{ const theme = useTheme() const action = actionCreator() const [values, setValues] = useState<(number | null)[]>( - widgets.map(widget => { + widgets.map((widget) => { if (widget === undefined) { return null } else if (widget.type === "text") { @@ -72,7 +72,9 @@ export const ObjectiveForm: FC<{ } action.makeTrialComplete(trial.study_id, trial.trial_id, filtered) } else if (outputType == "user_attr") { - const user_attrs = Object.fromEntries(widgets.map((widget, i) => [widget.user_attr_key, values[i]])) + const user_attrs = Object.fromEntries( + widgets.map((widget, i) => [widget.user_attr_key, values[i]]) + ) action.saveTrialUserAttrs(trial.study_id, trial.trial_id, user_attrs) } } diff --git a/optuna_dashboard/ts/components/TrialList.tsx b/optuna_dashboard/ts/components/TrialList.tsx index 06d3227e2..7192d942c 100644 --- a/optuna_dashboard/ts/components/TrialList.tsx +++ b/optuna_dashboard/ts/components/TrialList.tsx @@ -151,7 +151,7 @@ const TrialListDetail: FC<{ directions, objectiveNames, objectiveFormWidgets, - formWigetsOutputType + formWigetsOutputType, }) => { const theme = useTheme() const artifactEnabled = useRecoilValue(artifactIsAvailable) From 9c477cb3afade0c8423b9fa82f762695c5d37674 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Mon, 20 Mar 2023 15:39:18 +0900 Subject: [PATCH 15/23] Add validation of `user_attr_key` --- optuna_dashboard/_objective_form_widget.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 7c6940d96..732350cac 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -4,6 +4,7 @@ import json from typing import TYPE_CHECKING from typing import Union +import warnings import optuna @@ -132,6 +133,8 @@ def register_objective_form_widgets( ) -> None: if len(study.directions) != len(widgets): raise ValueError("The length of actions must be the same with the number of objectives.") + if any(w.user_attr_key is not None for w in widgets): + warnings.warn("`user_attr_key` specified, but it will not be used.") widget_dicts = [w.to_dict() for w in widgets] study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts) study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "objective") @@ -140,6 +143,10 @@ def register_objective_form_widgets( def register_user_attr_form_widgets( study: optuna.Study, widgets: list[ObjectiveFormWidget] ) -> None: + if any(w.user_attr_key is None for w in widgets): + raise ValueError("`user_attr_key` is not specified.") + if len(widgets) != len(set(w.user_attr_key for w in widgets)): + raise ValueError("`user_attr_key` must be unique for each widget.") widget_dicts = [w.to_dict() for w in widgets] study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts) study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "user_attr") From 017ac675e55a07471d28e1527298eac1a36e6fb5 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Mon, 20 Mar 2023 15:50:10 +0900 Subject: [PATCH 16/23] Remove "Object" prefix from widget classes --- optuna_dashboard/__init__.py | 3 +++ optuna_dashboard/_objective_form_widget.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/optuna_dashboard/__init__.py b/optuna_dashboard/__init__.py index efe3e1d19..33687ea09 100644 --- a/optuna_dashboard/__init__.py +++ b/optuna_dashboard/__init__.py @@ -2,12 +2,15 @@ from ._app import wsgi # noqa from ._named_objectives import set_objective_names # noqa from ._note import save_note # noqa +from ._objective_form_widget import ChoiceWidget # noqa from ._objective_form_widget import ObjectiveChoiceWidget # noqa from ._objective_form_widget import ObjectiveSliderWidget # noqa from ._objective_form_widget import ObjectiveTextInputWidget # noqa from ._objective_form_widget import ObjectiveUserAttrRef # noqa from ._objective_form_widget import register_objective_form_widgets # noqa from ._objective_form_widget import register_user_attr_form_widgets # noqa +from ._objective_form_widget import SliderWidget # noqa +from ._objective_form_widget import TextInputWidget # noqa __version__ = "0.9.0" diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 732350cac..548be60fc 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -55,7 +55,7 @@ @dataclass -class ObjectiveChoiceWidget: +class ChoiceWidget: choices: list[str] values: list[float] description: Optional[str] = None @@ -72,7 +72,7 @@ def to_dict(self) -> ChoiceWidgetJSON: @dataclass -class ObjectiveSliderWidget: +class SliderWidget: min: float max: float step: Optional[float] = None @@ -96,7 +96,7 @@ def to_dict(self) -> SliderWidgetJSON: @dataclass -class ObjectiveTextInputWidget: +class TextInputWidget: description: Optional[str] = None user_attr_key: Optional[str] = None @@ -121,9 +121,11 @@ def to_dict(self) -> UserAttrRefJSON: } -ObjectiveFormWidget = Union[ - ObjectiveChoiceWidget, ObjectiveSliderWidget, ObjectiveTextInputWidget, ObjectiveUserAttrRef -] +ObjectiveFormWidget = Union[ChoiceWidget, SliderWidget, TextInputWidget, ObjectiveUserAttrRef] +# For backward compatibility. +ObjectiveChoiceWidget = ChoiceWidget +ObjectiveSliderWidget = SliderWidget +ObjectiveTextInputWidget = TextInputWidget SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets:v1" SYSTEM_ATTR_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v1" From 77ed17968eae8b693d9d7b8d0b23b441eaaae81c Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Mon, 20 Mar 2023 16:40:00 +0900 Subject: [PATCH 17/23] Remove default text form for no widgets --- .../ts/components/ObjectiveForm.tsx | 43 +------------------ 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx index 379550ab8..d198e318e 100644 --- a/optuna_dashboard/ts/components/ObjectiveForm.tsx +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -117,36 +117,7 @@ export const ObjectiveForm: FC<{ {widgets.map((widget, i) => { const value = values.at(i) const key = `objective-${i}` - if (widget === undefined) { - return ( - - {getMetricName(i)} - { - const n = Number(s) - if (s.length > 0 && valid && !isNaN(n)) { - setValue(i, n) - return - } else if (values.at(i) !== null) { - setValue(i, null) - } - }} - delay={500} - textFieldProps={{ - required: true, - autoFocus: true, - fullWidth: true, - helperText: - value === null || value === undefined - ? `Please input the float number.` - : "", - label: getMetricName(i), - type: "text", - }} - /> - - ) - } else if (widget.type === "text") { + if (widget.type === "text") { return ( @@ -331,17 +302,7 @@ export const ReadonlyObjectiveForm: FC<{ > {widgets.map((widget, i) => { const key = `objective-${i}` - if (widget === undefined) { - return ( - - {getMetricName(i)} - - - ) - } else if (widget.type === "text") { + if (widget.type === "text") { return ( From 091774f0291541cf0ddf5973213d77aee25bd0a2 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Tue, 28 Mar 2023 12:16:46 +0900 Subject: [PATCH 18/23] Validate user_attrs have at least one element Co-authored-by: Masashi Shibata --- optuna_dashboard/_app.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index 54d07fa43..da2dded4c 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -457,12 +457,13 @@ def tell_trial(trial_id: int) -> dict[str, Any]: @app.post("/api/trials//user-attrs") @json_api_view def save_trial_user_attrs(trial_id: int) -> dict[str, Any]: - if "user_attrs" not in request.json: + user_attrs = requests.json.get("user_attrs", {}) + if not user_attrs: response.status = 400 # Bad request return {"reason": "user_attrs must be specified."} - try: # TODO(knshnb): Proper error handling. - for key, val in request.json.get("user_attrs").items(): + try: + for key, val in user_attrs.items(): storage.set_trial_user_attr(trial_id, key, val) except Exception as e: response.status = 500 From a5c7b59dfa05ee376e5e9e5b2b63c10c81c077d6 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Tue, 28 Mar 2023 12:21:24 +0900 Subject: [PATCH 19/23] Fix typo --- optuna_dashboard/_app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index da2dded4c..32313983e 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -457,7 +457,7 @@ def tell_trial(trial_id: int) -> dict[str, Any]: @app.post("/api/trials//user-attrs") @json_api_view def save_trial_user_attrs(trial_id: int) -> dict[str, Any]: - user_attrs = requests.json.get("user_attrs", {}) + user_attrs = request.json.get("user_attrs", {}) if not user_attrs: response.status = 400 # Bad request return {"reason": "user_attrs must be specified."} From 6489092a9a15c212114303e6e17f7dc72641b79d Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Tue, 28 Mar 2023 17:40:39 +0900 Subject: [PATCH 20/23] Update form widgets key name --- optuna_dashboard/_app.py | 4 ++-- optuna_dashboard/_objective_form_widget.py | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index 32313983e..ee3cb09dc 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -34,7 +34,7 @@ from ._bottle_util import json_api_view from ._cached_extra_study_property import get_cached_extra_study_property from ._importance import get_param_importance_from_trials_cache -from ._objective_form_widget import SYSTEM_ATTR_OUTPUT_TYPE_KEY +from ._objective_form_widget import FORM_WIDGETS_OUTPUT_TYPE_KEY from ._pareto_front import get_pareto_front_trials from ._serializer import serialize_study_detail from ._serializer import serialize_study_summary @@ -359,7 +359,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: has_intermediate_values, ) = get_cached_extra_study_property(study_id, trials) form_widgets_output_type = storage.get_study_system_attrs(study_id).get( - SYSTEM_ATTR_OUTPUT_TYPE_KEY + FORM_WIDGETS_OUTPUT_TYPE_KEY ) if form_widgets_output_type is None: form_widgets_output_type = "objective" diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 548be60fc..8f29b64ee 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -126,8 +126,8 @@ def to_dict(self) -> UserAttrRefJSON: ObjectiveChoiceWidget = ChoiceWidget ObjectiveSliderWidget = SliderWidget ObjectiveTextInputWidget = TextInputWidget -SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets:v1" -SYSTEM_ATTR_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v1" +FORM_WIDGETS_KEY = "dashboard:form_widgets:v2" +FORM_WIDGETS_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v2" def register_objective_form_widgets( @@ -138,8 +138,10 @@ def register_objective_form_widgets( if any(w.user_attr_key is not None for w in widgets): warnings.warn("`user_attr_key` specified, but it will not be used.") widget_dicts = [w.to_dict() for w in widgets] - study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts) - study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "objective") + study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, widget_dicts) + study._storage.set_study_system_attr( + study._study_id, FORM_WIDGETS_OUTPUT_TYPE_KEY, "objective" + ) def register_user_attr_form_widgets( @@ -150,15 +152,17 @@ def register_user_attr_form_widgets( if len(widgets) != len(set(w.user_attr_key for w in widgets)): raise ValueError("`user_attr_key` must be unique for each widget.") widget_dicts = [w.to_dict() for w in widgets] - study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts) - study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "user_attr") + study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, widget_dicts) + study._storage.set_study_system_attr( + study._study_id, FORM_WIDGETS_OUTPUT_TYPE_KEY, "user_attr" + ) def get_objective_form_widgets_json( study_system_attr: dict[str, Any] ) -> Optional[list[ObjectiveFormWidgetJSON]]: - if SYSTEM_ATTR_KEY in study_system_attr: - return study_system_attr[SYSTEM_ATTR_KEY] + if FORM_WIDGETS_KEY in study_system_attr: + return study_system_attr[FORM_WIDGETS_KEY] # For optuna-dashboard v0.9.0b5 users if "dashboard:objective_form_widgets" in study_system_attr: return json.loads(study_system_attr["dashboard:objective_form_widgets"]) From 416c6c7421428b6aa22a17edbbcaa6be4245643a Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Wed, 12 Apr 2023 15:45:24 +0900 Subject: [PATCH 21/23] Include form_widgets_output_type in form_widgets --- optuna_dashboard/_app.py | 7 --- optuna_dashboard/_objective_form_widget.py | 48 ++++++++++++-------- optuna_dashboard/_serializer.py | 7 +-- optuna_dashboard/ts/components/TrialList.tsx | 4 +- optuna_dashboard/ts/types/index.d.ts | 8 +++- 5 files changed, 40 insertions(+), 34 deletions(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index ee3cb09dc..a72c60c21 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -34,7 +34,6 @@ from ._bottle_util import json_api_view from ._cached_extra_study_property import get_cached_extra_study_property from ._importance import get_param_importance_from_trials_cache -from ._objective_form_widget import FORM_WIDGETS_OUTPUT_TYPE_KEY from ._pareto_front import get_pareto_front_trials from ._serializer import serialize_study_detail from ._serializer import serialize_study_summary @@ -358,11 +357,6 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union_user_attrs, has_intermediate_values, ) = get_cached_extra_study_property(study_id, trials) - form_widgets_output_type = storage.get_study_system_attrs(study_id).get( - FORM_WIDGETS_OUTPUT_TYPE_KEY - ) - if form_widgets_output_type is None: - form_widgets_output_type = "objective" return serialize_study_detail( summary, best_trials, @@ -371,7 +365,6 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union, union_user_attrs, has_intermediate_values, - form_widgets_output_type, ) @app.get("/api/studies//param_importances") diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 8f29b64ee..70bb5b6d7 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -49,9 +49,13 @@ "UserAttrRefJSON", {"type": Literal["user_attr"], "key": str, "user_attr_key": Optional[str]}, ) - ObjectiveFormWidgetJSON = Union[ - ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON - ] + FormWidgetJSON = TypedDict( + "FormWidgetJSON", + { + "output_type": Literal["objective", "user_attr"], + "widgets": list[Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON]] + } + ) @dataclass @@ -127,7 +131,6 @@ def to_dict(self) -> UserAttrRefJSON: ObjectiveSliderWidget = SliderWidget ObjectiveTextInputWidget = TextInputWidget FORM_WIDGETS_KEY = "dashboard:form_widgets:v2" -FORM_WIDGETS_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v2" def register_objective_form_widgets( @@ -137,11 +140,11 @@ def register_objective_form_widgets( raise ValueError("The length of actions must be the same with the number of objectives.") if any(w.user_attr_key is not None for w in widgets): warnings.warn("`user_attr_key` specified, but it will not be used.") - widget_dicts = [w.to_dict() for w in widgets] - study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, widget_dicts) - study._storage.set_study_system_attr( - study._study_id, FORM_WIDGETS_OUTPUT_TYPE_KEY, "objective" - ) + form_widgets: FormWidgetJSON = { + "output_type": "objective", + "widgets": [w.to_dict() for w in widgets], + } + study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets) def register_user_attr_form_widgets( @@ -151,19 +154,28 @@ def register_user_attr_form_widgets( raise ValueError("`user_attr_key` is not specified.") if len(widgets) != len(set(w.user_attr_key for w in widgets)): raise ValueError("`user_attr_key` must be unique for each widget.") - widget_dicts = [w.to_dict() for w in widgets] - study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, widget_dicts) - study._storage.set_study_system_attr( - study._study_id, FORM_WIDGETS_OUTPUT_TYPE_KEY, "user_attr" - ) + form_widgets: FormWidgetJSON = { + "output_type": "user_attr", + "widgets": [w.to_dict() for w in widgets], + } + study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets) -def get_objective_form_widgets_json( - study_system_attr: dict[str, Any] -) -> Optional[list[ObjectiveFormWidgetJSON]]: +def get_form_widgets_json(study_system_attr: dict[str, Any]) -> Optional[FormWidgetJSON]: if FORM_WIDGETS_KEY in study_system_attr: return study_system_attr[FORM_WIDGETS_KEY] + + # For optuna-dashboard v0.9.0 and v0.9.0b6 users + if "dashboard:objective_form_widgets:v1" in study_system_attr: + return { + "output_type": "objective", + "widgets": study_system_attr["dashboard:objective_form_widgets:v1"] + } + # For optuna-dashboard v0.9.0b5 users if "dashboard:objective_form_widgets" in study_system_attr: - return json.loads(study_system_attr["dashboard:objective_form_widgets"]) + return { + "output_type": "objective", + "widgets": json.loads(study_system_attr["dashboard:objective_form_widgets"]) + } return None diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index b5ef1024c..9f6b0bc17 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -2,7 +2,6 @@ import json from typing import Any -from typing import Optional from typing import TYPE_CHECKING from typing import Union @@ -14,7 +13,7 @@ from . import _note as note from ._named_objectives import get_objective_names -from ._objective_form_widget import get_objective_form_widgets_json +from ._objective_form_widget import get_form_widgets_json from .artifact._backend import list_trial_artifacts @@ -122,7 +121,6 @@ def serialize_study_detail( union: list[tuple[str, BaseDistribution]], union_user_attrs: list[tuple[str, bool]], has_intermediate_values: bool, - form_widgets_output_type: Optional[str], ) -> dict[str, Any]: serialized: dict[str, Any] = { "name": summary.study_name, @@ -146,10 +144,9 @@ def serialize_study_detail( objective_names = get_objective_names(system_attrs) if objective_names: serialized["objective_names"] = objective_names - objective_form_widgets = get_objective_form_widgets_json(system_attrs) + objective_form_widgets = get_form_widgets_json(system_attrs) if objective_form_widgets: serialized["objective_form_widgets"] = objective_form_widgets - serialized["form_widgets_output_type"] = form_widgets_output_type return serialized diff --git a/optuna_dashboard/ts/components/TrialList.tsx b/optuna_dashboard/ts/components/TrialList.tsx index 7192d942c..8a72da992 100644 --- a/optuna_dashboard/ts/components/TrialList.tsx +++ b/optuna_dashboard/ts/components/TrialList.tsx @@ -818,10 +818,10 @@ export const TrialList: FC<{ studyDetail: StudyDetail | null }> = ({ directions={studyDetail?.directions || []} objectiveNames={studyDetail?.objective_names || []} objectiveFormWidgets={ - studyDetail?.objective_form_widgets || [] + studyDetail?.objective_form_widgets?.widgets || [] } formWigetsOutputType={ - studyDetail?.form_widgets_output_type || "" + studyDetail?.objective_form_widgets?.output_type || "" } /> ))} diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index 4410691af..4fea1a7e8 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -166,6 +166,11 @@ type ObjectiveFormWidget = | ObjectiveTextInputWidget | ObjectiveUserAttrRef +type FormWidgets = { + "output_type": string + "widgets": ObjectiveFormWidget[] +} + type StudyDetail = { id: number name: string @@ -179,8 +184,7 @@ type StudyDetail = { has_intermediate_values: boolean note: Note objective_names?: string[] - objective_form_widgets?: ObjectiveFormWidget[] - form_widgets_output_type?: string + objective_form_widgets?: FormWidgets } type StudyDetails = { From 765b0d3620c2e075ef352b0b5c0c2febd1d8aaa1 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Wed, 12 Apr 2023 16:13:00 +0900 Subject: [PATCH 22/23] Apply formatter --- optuna_dashboard/_objective_form_widget.py | 10 ++++++---- optuna_dashboard/ts/types/index.d.ts | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 70bb5b6d7..7b076ec12 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -53,8 +53,10 @@ "FormWidgetJSON", { "output_type": Literal["objective", "user_attr"], - "widgets": list[Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON]] - } + "widgets": list[ + Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON] + ], + }, ) @@ -169,13 +171,13 @@ def get_form_widgets_json(study_system_attr: dict[str, Any]) -> Optional[FormWid if "dashboard:objective_form_widgets:v1" in study_system_attr: return { "output_type": "objective", - "widgets": study_system_attr["dashboard:objective_form_widgets:v1"] + "widgets": study_system_attr["dashboard:objective_form_widgets:v1"], } # For optuna-dashboard v0.9.0b5 users if "dashboard:objective_form_widgets" in study_system_attr: return { "output_type": "objective", - "widgets": json.loads(study_system_attr["dashboard:objective_form_widgets"]) + "widgets": json.loads(study_system_attr["dashboard:objective_form_widgets"]), } return None diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index 4fea1a7e8..b4f9a69ae 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -167,8 +167,8 @@ type ObjectiveFormWidget = | ObjectiveUserAttrRef type FormWidgets = { - "output_type": string - "widgets": ObjectiveFormWidget[] + output_type: string + widgets: ObjectiveFormWidget[] } type StudyDetail = { From 031a1e103a499349218e396faccd94a2edd5a6c6 Mon Sep 17 00:00:00 2001 From: Kenshin Abe Date: Wed, 12 Apr 2023 16:34:38 +0900 Subject: [PATCH 23/23] Fix type --- optuna_dashboard/ts/apiClient.ts | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index 849bf2871..868ba6926 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -67,8 +67,7 @@ interface StudyDetailResponse { has_intermediate_values: boolean note: Note objective_names?: string[] - objective_form_widgets?: ObjectiveFormWidget[] - form_widgets_output_type?: string + objective_form_widgets?: FormWidgets } export const getStudyDetailAPI = ( @@ -102,7 +101,6 @@ export const getStudyDetailAPI = ( note: res.data.note, objective_names: res.data.objective_names, objective_form_widgets: res.data.objective_form_widgets, - form_widgets_output_type: res.data.form_widgets_output_type, } }) }