Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support human-in-the-loop via trial.user_attrs #411

Merged
merged 24 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
from ._app import wsgi # noqa
from ._named_objectives import set_objective_names # noqa
from ._note import save_note # noqa
from ._objective_form_widget import ChoiceWidget # noqa
from ._objective_form_widget import ObjectiveChoiceWidget # noqa
from ._objective_form_widget import ObjectiveSliderWidget # noqa
from ._objective_form_widget import ObjectiveTextInputWidget # noqa
from ._objective_form_widget import ObjectiveUserAttrRef # noqa
from ._objective_form_widget import register_objective_form_widgets # noqa
from ._objective_form_widget import register_user_attr_form_widgets # noqa
from ._objective_form_widget import SliderWidget # noqa
from ._objective_form_widget import TextInputWidget # noqa


__version__ = "0.9.0"
18 changes: 18 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,24 @@ def tell_trial(trial_id: int) -> dict[str, Any]:
response.status = 204
return {}

@app.post("/api/trials/<trial_id:int>/user-attrs")
@json_api_view
def save_trial_user_attrs(trial_id: int) -> dict[str, Any]:
user_attrs = request.json.get("user_attrs", {})
if not user_attrs:
response.status = 400 # Bad request
return {"reason": "user_attrs must be specified."}

try:
for key, val in user_attrs.items():
storage.set_trial_user_attr(trial_id, key, val)
except Exception as e:
response.status = 500
return {"reason": f"Internal server error: {e}"}

response.status = 204
return {}

@app.put("/api/studies/<study_id:int>/<trial_id:int>/note")
@json_api_view
def save_trial_note(study_id: int, trial_id: int) -> dict[str, Any]:
Expand Down
90 changes: 70 additions & 20 deletions optuna_dashboard/_objective_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from typing import TYPE_CHECKING
from typing import Union
import warnings

import optuna

Expand All @@ -21,6 +22,7 @@
"description": Optional[str],
"choices": list[str],
"values": list[float],
"user_attr_key": Optional[str],
},
)
SliderWidgetLabel = TypedDict(
Expand All @@ -36,40 +38,53 @@
"max": float,
"step": Optional[float],
"labels": Optional[list[SliderWidgetLabel]],
"user_attr_key": Optional[str],
},
)
TextInputWidgetJSON = TypedDict(
"TextInputWidgetJSON",
{"type": Literal["text"], "description": Optional[str]},
{"type": Literal["text"], "description": Optional[str], "user_attr_key": Optional[str]},
)
UserAttrRefJSON = TypedDict(
"UserAttrRefJSON",
{"type": Literal["user_attr"], "key": str, "user_attr_key": Optional[str]},
)
FormWidgetJSON = TypedDict(
"FormWidgetJSON",
{
"output_type": Literal["objective", "user_attr"],
"widgets": list[
Union[ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON]
],
},
)
UserAttrRefJSON = TypedDict("UserAttrRefJSON", {"type": Literal["user_attr"], "key": str})
ObjectiveFormWidgetJSON = Union[
ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON
]


@dataclass
class ObjectiveChoiceWidget:
class ChoiceWidget:
choices: list[str]
values: list[float]
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> ChoiceWidgetJSON:
return {
"type": "choice",
"description": self.description,
"choices": self.choices,
"values": self.values,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveSliderWidget:
class SliderWidget:
min: float
max: float
step: Optional[float] = None
labels: Optional[list[tuple[float, str]]] = None
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> SliderWidgetJSON:
labels: Optional[list[SliderWidgetLabel]] = None
Expand All @@ -82,52 +97,87 @@ def to_dict(self) -> SliderWidgetJSON:
"max": self.max,
"step": self.step,
"labels": labels,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveTextInputWidget:
class TextInputWidget:
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> TextInputWidgetJSON:
return {
"type": "text",
"description": self.description,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveUserAttrRef:
key: str
user_attr_key: Optional[str] = None

def to_dict(self) -> UserAttrRefJSON:
return {
"type": "user_attr",
"key": self.key,
"user_attr_key": self.user_attr_key,
}


ObjectiveFormWidget = Union[
ObjectiveChoiceWidget, ObjectiveSliderWidget, ObjectiveTextInputWidget, ObjectiveUserAttrRef
]
SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets:v1"
ObjectiveFormWidget = Union[ChoiceWidget, SliderWidget, TextInputWidget, ObjectiveUserAttrRef]
# For backward compatibility.
ObjectiveChoiceWidget = ChoiceWidget
ObjectiveSliderWidget = SliderWidget
ObjectiveTextInputWidget = TextInputWidget
FORM_WIDGETS_KEY = "dashboard:form_widgets:v2"


def register_objective_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if len(study.directions) != len(widgets):
raise ValueError("The length of actions must be the same with the number of objectives.")
widget_dicts = [w.to_dict() for w in widgets]
study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts)
if any(w.user_attr_key is not None for w in widgets):
warnings.warn("`user_attr_key` specified, but it will not be used.")
form_widgets: FormWidgetJSON = {
"output_type": "objective",
"widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)


def get_objective_form_widgets_json(
study_system_attr: dict[str, Any]
) -> Optional[list[ObjectiveFormWidgetJSON]]:
if SYSTEM_ATTR_KEY in study_system_attr:
return study_system_attr[SYSTEM_ATTR_KEY]
def register_user_attr_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if any(w.user_attr_key is None for w in widgets):
raise ValueError("`user_attr_key` is not specified.")
if len(widgets) != len(set(w.user_attr_key for w in widgets)):
raise ValueError("`user_attr_key` must be unique for each widget.")
form_widgets: FormWidgetJSON = {
"output_type": "user_attr",
"widgets": [w.to_dict() for w in widgets],
}
study._storage.set_study_system_attr(study._study_id, FORM_WIDGETS_KEY, form_widgets)


def get_form_widgets_json(study_system_attr: dict[str, Any]) -> Optional[FormWidgetJSON]:
if FORM_WIDGETS_KEY in study_system_attr:
return study_system_attr[FORM_WIDGETS_KEY]

# For optuna-dashboard v0.9.0 and v0.9.0b6 users
if "dashboard:objective_form_widgets:v1" in study_system_attr:
return {
"output_type": "objective",
"widgets": study_system_attr["dashboard:objective_form_widgets:v1"],
}

# For optuna-dashboard v0.9.0b5 users
if "dashboard:objective_form_widgets" in study_system_attr:
return json.loads(study_system_attr["dashboard:objective_form_widgets"])
return {
"output_type": "objective",
"widgets": json.loads(study_system_attr["dashboard:objective_form_widgets"]),
}
return None
4 changes: 2 additions & 2 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from . import _note as note
from ._named_objectives import get_objective_names
from ._objective_form_widget import get_objective_form_widgets_json
from ._objective_form_widget import get_form_widgets_json
from .artifact._backend import list_trial_artifacts


Expand Down Expand Up @@ -144,7 +144,7 @@ def serialize_study_detail(
objective_names = get_objective_names(system_attrs)
if objective_names:
serialized["objective_names"] = objective_names
objective_form_widgets = get_objective_form_widgets_json(system_attrs)
objective_form_widgets = get_form_widgets_json(system_attrs)
if objective_form_widgets:
serialized["objective_form_widgets"] = objective_form_widgets
return serialized
Expand Down
56 changes: 56 additions & 0 deletions optuna_dashboard/ts/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
saveStudyNoteAPI,
saveTrialNoteAPI,
tellTrialAPI,
saveTrialUserAttrsAPI,
renameStudyAPI,
uploadArtifactAPI,
getMetaInfoAPI,
Expand Down Expand Up @@ -151,6 +152,26 @@ export const actionCreator = () => {
setStudyDetailState(studyId, newStudy)
}

const setTrialUserAttrs = (
studyId: number,
index: number,
user_attrs: { [key: string]: number }
) => {
const newTrial: Trial = Object.assign(
{},
studyDetails[studyId].trials[index]
)
newTrial.user_attrs = Object.keys(user_attrs).map((key) => ({
key: key,
value: user_attrs[key].toString(),
}))
const newTrials: Trial[] = [...studyDetails[studyId].trials]
newTrials[index] = newTrial
const newStudy: StudyDetail = Object.assign({}, studyDetails[studyId])
newStudy.trials = newTrials
setStudyDetailState(studyId, newStudy)
}

const setStudyParamImportanceState = (
studyId: number,
importance: ParamImportance[][]
Expand Down Expand Up @@ -508,6 +529,40 @@ export const actionCreator = () => {
})
}

const saveTrialUserAttrs = (
studyId: number,
trialId: number,
user_attrs: { [key: string]: number }
): void => {
console.log("user_attrs", user_attrs)
const message = `id=${trialId}, user_attrs=${JSON.stringify(user_attrs)}`
saveTrialUserAttrsAPI(trialId, user_attrs)
.then(() => {
const index = studyDetails[studyId].trials.findIndex(
(t) => t.trial_id === trialId
)
if (index === -1) {
enqueueSnackbar(`Unexpected error happens. Please reload the page.`, {
variant: "error",
})
return
}
setTrialUserAttrs(studyId, index, user_attrs)
enqueueSnackbar(`Successfully updated trial (${message})`, {
variant: "success",
})
})
.catch((err) => {
const reason = err.response?.data.reason
enqueueSnackbar(
`Failed to update trial (${message}). Reason: ${reason}`,
{
variant: "error",
}
)
console.log(err)
})
}
return {
updateAPIMeta,
updateStudyDetail,
Expand All @@ -526,6 +581,7 @@ export const actionCreator = () => {
deleteArtifact,
makeTrialComplete,
makeTrialFail,
saveTrialUserAttrs,
}
}

Expand Down
15 changes: 14 additions & 1 deletion optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ interface StudyDetailResponse {
has_intermediate_values: boolean
note: Note
objective_names?: string[]
objective_form_widgets?: ObjectiveFormWidget[]
objective_form_widgets?: FormWidgets
}

export const getStudyDetailAPI = (
Expand Down Expand Up @@ -281,6 +281,19 @@ export const tellTrialAPI = (
})
}

export const saveTrialUserAttrsAPI = (
trialId: number,
user_attrs: { [key: string]: number }
): Promise<void> => {
const req = { user_attrs: user_attrs }

return axiosInstance
.post<void>(`/api/trials/${trialId}/user-attrs`, req)
.then((res) => {
return
})
}

interface ParamImportancesResponse {
param_importances: ParamImportance[][]
}
Expand Down
Loading