Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support human-in-the-loop via trial.user_attrs #411

Merged
merged 24 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
from ._app import wsgi # noqa
from ._named_objectives import set_objective_names # noqa
from ._note import save_note # noqa
from ._objective_form_widget import ChoiceWidget # noqa
from ._objective_form_widget import ObjectiveChoiceWidget # noqa
from ._objective_form_widget import ObjectiveSliderWidget # noqa
from ._objective_form_widget import ObjectiveTextInputWidget # noqa
from ._objective_form_widget import ObjectiveUserAttrRef # noqa
from ._objective_form_widget import register_objective_form_widgets # noqa
from ._objective_form_widget import register_user_attr_form_widgets # noqa
from ._objective_form_widget import SliderWidget # noqa
from ._objective_form_widget import TextInputWidget # noqa


__version__ = "0.9.0"
24 changes: 24 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ._bottle_util import json_api_view
from ._cached_extra_study_property import get_cached_extra_study_property
from ._importance import get_param_importance_from_trials_cache
from ._objective_form_widget import SYSTEM_ATTR_OUTPUT_TYPE_KEY
from ._pareto_front import get_pareto_front_trials
from ._serializer import serialize_study_detail
from ._serializer import serialize_study_summary
Expand Down Expand Up @@ -357,6 +358,11 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union_user_attrs,
has_intermediate_values,
) = get_cached_extra_study_property(study_id, trials)
form_widgets_output_type = storage.get_study_system_attrs(study_id).get(
SYSTEM_ATTR_OUTPUT_TYPE_KEY
)
if form_widgets_output_type is None:
form_widgets_output_type = "objective"
return serialize_study_detail(
summary,
best_trials,
Expand All @@ -365,6 +371,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union,
union_user_attrs,
has_intermediate_values,
form_widgets_output_type,
)

@app.get("/api/studies/<study_id:int>/param_importances")
Expand Down Expand Up @@ -447,6 +454,23 @@ def tell_trial(trial_id: int) -> dict[str, Any]:
response.status = 204
return {}

@app.post("/api/trials/<trial_id:int>/user-attrs")
@json_api_view
def save_trial_user_attrs(trial_id: int) -> dict[str, Any]:
if "user_attrs" not in request.json:
response.status = 400 # Bad request
return {"reason": "user_attrs must be specified."}

try: # TODO(knshnb): Proper error handling.
for key, val in request.json.get("user_attrs").items():
storage.set_trial_user_attr(trial_id, key, val)
knshnb marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
response.status = 500
return {"reason": f"Internal server error: {e}"}

response.status = 204
return {}

@app.put("/api/studies/<study_id:int>/<trial_id:int>/note")
@json_api_view
def save_trial_note(study_id: int, trial_id: int) -> dict[str, Any]:
Expand Down
48 changes: 40 additions & 8 deletions optuna_dashboard/_objective_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from typing import TYPE_CHECKING
from typing import Union
import warnings

import optuna

Expand All @@ -21,6 +22,7 @@
"description": Optional[str],
"choices": list[str],
"values": list[float],
"user_attr_key": Optional[str],
},
)
SliderWidgetLabel = TypedDict(
Expand All @@ -36,40 +38,47 @@
"max": float,
"step": Optional[float],
"labels": Optional[list[SliderWidgetLabel]],
"user_attr_key": Optional[str],
},
)
TextInputWidgetJSON = TypedDict(
"TextInputWidgetJSON",
{"type": Literal["text"], "description": Optional[str]},
{"type": Literal["text"], "description": Optional[str], "user_attr_key": Optional[str]},
)
UserAttrRefJSON = TypedDict(
"UserAttrRefJSON",
{"type": Literal["user_attr"], "key": str, "user_attr_key": Optional[str]},
)
UserAttrRefJSON = TypedDict("UserAttrRefJSON", {"type": Literal["user_attr"], "key": str})
ObjectiveFormWidgetJSON = Union[
ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON
]


@dataclass
class ObjectiveChoiceWidget:
class ChoiceWidget:
choices: list[str]
values: list[float]
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> ChoiceWidgetJSON:
return {
"type": "choice",
"description": self.description,
"choices": self.choices,
"values": self.values,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveSliderWidget:
class SliderWidget:
min: float
max: float
step: Optional[float] = None
labels: Optional[list[tuple[float, str]]] = None
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> SliderWidgetJSON:
labels: Optional[list[SliderWidgetLabel]] = None
Expand All @@ -82,44 +91,67 @@ def to_dict(self) -> SliderWidgetJSON:
"max": self.max,
"step": self.step,
"labels": labels,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveTextInputWidget:
class TextInputWidget:
description: Optional[str] = None
user_attr_key: Optional[str] = None

def to_dict(self) -> TextInputWidgetJSON:
return {
"type": "text",
"description": self.description,
"user_attr_key": self.user_attr_key,
}


@dataclass
class ObjectiveUserAttrRef:
key: str
user_attr_key: Optional[str] = None

def to_dict(self) -> UserAttrRefJSON:
return {
"type": "user_attr",
"key": self.key,
"user_attr_key": self.user_attr_key,
}


ObjectiveFormWidget = Union[
ObjectiveChoiceWidget, ObjectiveSliderWidget, ObjectiveTextInputWidget, ObjectiveUserAttrRef
]
ObjectiveFormWidget = Union[ChoiceWidget, SliderWidget, TextInputWidget, ObjectiveUserAttrRef]
# For backward compatibility.
ObjectiveChoiceWidget = ChoiceWidget
ObjectiveSliderWidget = SliderWidget
ObjectiveTextInputWidget = TextInputWidget
SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets:v1"
SYSTEM_ATTR_OUTPUT_TYPE_KEY = "dashboard:form_widgets_output_type:v1"


def register_objective_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if len(study.directions) != len(widgets):
raise ValueError("The length of actions must be the same with the number of objectives.")
if any(w.user_attr_key is not None for w in widgets):
warnings.warn("`user_attr_key` specified, but it will not be used.")
widget_dicts = [w.to_dict() for w in widgets]
study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts)
study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "objective")


def register_user_attr_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if any(w.user_attr_key is None for w in widgets):
raise ValueError("`user_attr_key` is not specified.")
if len(widgets) != len(set(w.user_attr_key for w in widgets)):
raise ValueError("`user_attr_key` must be unique for each widget.")
widget_dicts = [w.to_dict() for w in widgets]
study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_KEY, widget_dicts)
study._storage.set_study_system_attr(study._study_id, SYSTEM_ATTR_OUTPUT_TYPE_KEY, "user_attr")


def get_objective_form_widgets_json(
Expand Down
3 changes: 3 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

Expand Down Expand Up @@ -121,6 +122,7 @@ def serialize_study_detail(
union: list[tuple[str, BaseDistribution]],
union_user_attrs: list[tuple[str, bool]],
has_intermediate_values: bool,
form_widgets_output_type: Optional[str],
) -> dict[str, Any]:
serialized: dict[str, Any] = {
"name": summary.study_name,
Expand All @@ -147,6 +149,7 @@ def serialize_study_detail(
objective_form_widgets = get_objective_form_widgets_json(system_attrs)
if objective_form_widgets:
serialized["objective_form_widgets"] = objective_form_widgets
serialized["form_widgets_output_type"] = form_widgets_output_type
return serialized


Expand Down
56 changes: 56 additions & 0 deletions optuna_dashboard/ts/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
saveStudyNoteAPI,
saveTrialNoteAPI,
tellTrialAPI,
saveTrialUserAttrsAPI,
renameStudyAPI,
uploadArtifactAPI,
getMetaInfoAPI,
Expand Down Expand Up @@ -151,6 +152,26 @@ export const actionCreator = () => {
setStudyDetailState(studyId, newStudy)
}

const setTrialUserAttrs = (
studyId: number,
index: number,
user_attrs: { [key: string]: number }
) => {
const newTrial: Trial = Object.assign(
{},
studyDetails[studyId].trials[index]
)
newTrial.user_attrs = Object.keys(user_attrs).map((key) => ({
key: key,
value: user_attrs[key].toString(),
}))
const newTrials: Trial[] = [...studyDetails[studyId].trials]
newTrials[index] = newTrial
const newStudy: StudyDetail = Object.assign({}, studyDetails[studyId])
newStudy.trials = newTrials
setStudyDetailState(studyId, newStudy)
}

const setStudyParamImportanceState = (
studyId: number,
importance: ParamImportance[][]
Expand Down Expand Up @@ -508,6 +529,40 @@ export const actionCreator = () => {
})
}

const saveTrialUserAttrs = (
studyId: number,
trialId: number,
user_attrs: { [key: string]: number }
): void => {
console.log("user_attrs", user_attrs)
const message = `id=${trialId}, user_attrs=${JSON.stringify(user_attrs)}`
saveTrialUserAttrsAPI(trialId, user_attrs)
.then(() => {
const index = studyDetails[studyId].trials.findIndex(
(t) => t.trial_id === trialId
)
if (index === -1) {
enqueueSnackbar(`Unexpected error happens. Please reload the page.`, {
variant: "error",
})
return
}
setTrialUserAttrs(studyId, index, user_attrs)
enqueueSnackbar(`Successfully updated trial (${message})`, {
variant: "success",
})
})
.catch((err) => {
const reason = err.response?.data.reason
enqueueSnackbar(
`Failed to update trial (${message}). Reason: ${reason}`,
{
variant: "error",
}
)
console.log(err)
})
}
return {
updateAPIMeta,
updateStudyDetail,
Expand All @@ -526,6 +581,7 @@ export const actionCreator = () => {
deleteArtifact,
makeTrialComplete,
makeTrialFail,
saveTrialUserAttrs,
}
}

Expand Down
15 changes: 15 additions & 0 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ interface StudyDetailResponse {
note: Note
objective_names?: string[]
objective_form_widgets?: ObjectiveFormWidget[]
form_widgets_output_type?: string
}

export const getStudyDetailAPI = (
Expand Down Expand Up @@ -101,6 +102,7 @@ export const getStudyDetailAPI = (
note: res.data.note,
objective_names: res.data.objective_names,
objective_form_widgets: res.data.objective_form_widgets,
form_widgets_output_type: res.data.form_widgets_output_type,
}
})
}
Expand Down Expand Up @@ -281,6 +283,19 @@ export const tellTrialAPI = (
})
}

export const saveTrialUserAttrsAPI = (
trialId: number,
user_attrs: { [key: string]: number }
): Promise<void> => {
const req = { user_attrs: user_attrs }

return axiosInstance
.post<void>(`/api/trials/${trialId}/user-attrs`, req)
.then((res) => {
return
})
}

interface ParamImportancesResponse {
param_importances: ParamImportance[][]
}
Expand Down
Loading