Skip to content

Commit

Permalink
Follow-up optuna#411 to refactor type definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Apr 13, 2023
1 parent 031a1e1 commit d6f3846
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 69 deletions.
18 changes: 9 additions & 9 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from ._app import run_server # noqa
from ._app import wsgi # noqa
from ._form_widget import ChoiceWidget # noqa
from ._form_widget import ObjectiveChoiceWidget # noqa
from ._form_widget import ObjectiveSliderWidget # noqa
from ._form_widget import ObjectiveTextInputWidget # noqa
from ._form_widget import ObjectiveUserAttrRef # noqa
from ._form_widget import register_objective_form_widgets # noqa
from ._form_widget import register_user_attr_form_widgets # noqa
from ._form_widget import SliderWidget # noqa
from ._form_widget import TextInputWidget # noqa
from ._named_objectives import set_objective_names # noqa
from ._note import save_note # noqa
from ._objective_form_widget import ChoiceWidget # noqa
from ._objective_form_widget import ObjectiveChoiceWidget # noqa
from ._objective_form_widget import ObjectiveSliderWidget # noqa
from ._objective_form_widget import ObjectiveTextInputWidget # noqa
from ._objective_form_widget import ObjectiveUserAttrRef # noqa
from ._objective_form_widget import register_objective_form_widgets # noqa
from ._objective_form_widget import register_user_attr_form_widgets # noqa
from ._objective_form_widget import SliderWidget # noqa
from ._objective_form_widget import TextInputWidget # noqa


__version__ = "0.9.0"
File renamed without changes.
8 changes: 4 additions & 4 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from optuna.trial import FrozenTrial

from . import _note as note
from ._form_widget import get_form_widgets_json
from ._named_objectives import get_objective_names
from ._objective_form_widget import get_form_widgets_json
from .artifact._backend import list_trial_artifacts


Expand Down Expand Up @@ -144,9 +144,9 @@ def serialize_study_detail(
objective_names = get_objective_names(system_attrs)
if objective_names:
serialized["objective_names"] = objective_names
objective_form_widgets = get_form_widgets_json(system_attrs)
if objective_form_widgets:
serialized["objective_form_widgets"] = objective_form_widgets
form_widgets = get_form_widgets_json(system_attrs)
if form_widgets:
serialized["form_widgets"] = form_widgets
return serialized


Expand Down
4 changes: 2 additions & 2 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ interface StudyDetailResponse {
has_intermediate_values: boolean
note: Note
objective_names?: string[]
objective_form_widgets?: FormWidgets
form_widgets?: FormWidgets
}

export const getStudyDetailAPI = (
Expand Down Expand Up @@ -100,7 +100,7 @@ export const getStudyDetailAPI = (
has_intermediate_values: res.data.has_intermediate_values,
note: res.data.note,
objective_names: res.data.objective_names,
objective_form_widgets: res.data.objective_form_widgets,
form_widgets: res.data.form_widgets,
}
})
}
Expand Down
37 changes: 19 additions & 18 deletions optuna_dashboard/ts/components/ObjectiveForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ export const ObjectiveForm: FC<{
trial: Trial
directions: StudyDirection[]
names: string[]
widgets: ObjectiveFormWidget[]
outputType: string
}> = ({ trial, directions, names, widgets, outputType }) => {
formWidgets: FormWidgets
}> = ({ trial, directions, names, formWidgets }) => {
const theme = useTheme()
const action = actionCreator()
const [values, setValues] = useState<(number | null)[]>(
widgets.map((widget) => {
formWidgets.widgets.map((widget) => {
if (widget === undefined) {
return null
} else if (widget.type === "text") {
Expand Down Expand Up @@ -65,22 +64,25 @@ export const ObjectiveForm: FC<{

const handleSubmit = (e: React.MouseEvent<HTMLButtonElement>): void => {
e.preventDefault()
if (outputType == "objective") {
if (formWidgets.output_type == "objective") {
const filtered = values.filter<number>((v): v is number => v !== null)
if (filtered.length !== directions.length) {
return
}
action.makeTrialComplete(trial.study_id, trial.trial_id, filtered)
} else if (outputType == "user_attr") {
} else if (formWidgets.output_type == "user_attr") {
const user_attrs = Object.fromEntries(
widgets.map((widget, i) => [widget.user_attr_key, values[i]])
formWidgets.widgets.map((widget, i) => [
widget.user_attr_key,
values[i],
])
)
action.saveTrialUserAttrs(trial.study_id, trial.trial_id, user_attrs)
}
}

const getMetricName = (i: number): string => {
if (outputType == "objective") {
if (formWidgets.output_type == "objective") {
const n = names.at(i)
if (n !== undefined) {
return n
Expand All @@ -90,8 +92,8 @@ export const ObjectiveForm: FC<{
} else {
return `Objective ${i}`
}
} else if (outputType == "user_attr") {
return widgets[i].user_attr_key as string
} else if (formWidgets.output_type == "user_attr") {
return formWidgets.widgets[i].user_attr_key as string
}
return "Unkown metric name"
}
Expand All @@ -114,7 +116,7 @@ export const ObjectiveForm: FC<{
p: theme.spacing(1),
}}
>
{widgets.map((widget, i) => {
{formWidgets.widgets.map((widget, i) => {
const value = values.at(i)
const key = `objective-${i}`
if (widget.type === "text") {
Expand Down Expand Up @@ -262,12 +264,11 @@ export const ReadonlyObjectiveForm: FC<{
trial: Trial
directions: StudyDirection[]
names: string[]
widgets: ObjectiveFormWidget[]
outputType: string
}> = ({ trial, directions, names, widgets, outputType }) => {
formWidgets: FormWidgets
}> = ({ trial, directions, names, formWidgets }) => {
const theme = useTheme()
const getMetricName = (i: number): string => {
if (outputType == "objective") {
if (formWidgets.output_type == "objective") {
const n = names.at(i)
if (n !== undefined) {
return n
Expand All @@ -277,8 +278,8 @@ export const ReadonlyObjectiveForm: FC<{
} else {
return `Objective ${i}`
}
} else if (outputType == "user_attr") {
return widgets[i].user_attr_key as string
} else if (formWidgets.output_type == "user_attr") {
return formWidgets.widgets[i].user_attr_key as string
}
return "Unkown metric name"
}
Expand All @@ -300,7 +301,7 @@ export const ReadonlyObjectiveForm: FC<{
p: theme.spacing(1),
}}
>
{widgets.map((widget, i) => {
{formWidgets.widgets.map((widget, i) => {
const key = `objective-${i}`
if (widget.type === "text") {
return (
Expand Down
57 changes: 23 additions & 34 deletions optuna_dashboard/ts/components/TrialList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,8 @@ const TrialListDetail: FC<{
isBestTrial: (trialId: number) => boolean
directions: StudyDirection[]
objectiveNames: string[]
objectiveFormWidgets: ObjectiveFormWidget[]
formWigetsOutputType: string
}> = ({
trial,
isBestTrial,
directions,
objectiveNames,
objectiveFormWidgets,
formWigetsOutputType,
}) => {
formWidgets?: FormWidgets
}> = ({ trial, isBestTrial, directions, objectiveNames, formWidgets }) => {
const theme = useTheme()
const artifactEnabled = useRecoilValue<boolean>(artifactIsAvailable)
const startMs = trial.datetime_start?.getTime()
Expand Down Expand Up @@ -293,24 +285,26 @@ const TrialListDetail: FC<{
latestNote={trial.note}
cardSx={{ marginBottom: theme.spacing(2) }}
/>
{trial.state === "Running" && directions.length > 0 && (
<ObjectiveForm
trial={trial}
directions={directions}
names={objectiveNames}
widgets={objectiveFormWidgets}
outputType={formWigetsOutputType}
/>
)}
{trial.state === "Complete" && directions.length > 0 && (
<ReadonlyObjectiveForm
trial={trial}
directions={directions}
names={objectiveNames}
widgets={objectiveFormWidgets}
outputType={formWigetsOutputType}
/>
)}
{trial.state === "Running" &&
directions.length > 0 &&
formWidgets !== undefined && (
<ObjectiveForm
trial={trial}
directions={directions}
names={objectiveNames}
formWidgets={formWidgets}
/>
)}
{trial.state === "Complete" &&
directions.length > 0 &&
formWidgets !== undefined && (
<ReadonlyObjectiveForm
trial={trial}
directions={directions}
names={objectiveNames}
formWidgets={formWidgets}
/>
)}
{artifactEnabled && <TrialArtifact trial={trial} />}
</Box>
)
Expand Down Expand Up @@ -817,12 +811,7 @@ export const TrialList: FC<{ studyDetail: StudyDetail | null }> = ({
isBestTrial={isBestTrial}
directions={studyDetail?.directions || []}
objectiveNames={studyDetail?.objective_names || []}
objectiveFormWidgets={
studyDetail?.objective_form_widgets?.widgets || []
}
formWigetsOutputType={
studyDetail?.objective_form_widgets?.output_type || ""
}
formWidgets={studyDetail?.form_widgets}
/>
))}
</Box>
Expand Down
5 changes: 3 additions & 2 deletions optuna_dashboard/ts/types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,15 @@ type ObjectiveUserAttrRef = {
user_attr_key?: string
}

// TODO(kenshin): Rename this type to FormWidget or something.
type ObjectiveFormWidget =
| ObjectiveChoiceWidget
| ObjectiveSliderWidget
| ObjectiveTextInputWidget
| ObjectiveUserAttrRef

type FormWidgets = {
output_type: string
output_type: "objective" | "user_attr"
widgets: ObjectiveFormWidget[]
}

Expand All @@ -184,7 +185,7 @@ type StudyDetail = {
has_intermediate_values: boolean
note: Note
objective_names?: string[]
objective_form_widgets?: FormWidgets
form_widgets?: FormWidgets
}

type StudyDetails = {
Expand Down

0 comments on commit d6f3846

Please sign in to comment.