Skip to content

Commit

Permalink
Merge pull request #674 from gen740/add_upload_study_artifact_api
Browse files Browse the repository at this point in the history
Add upload study artifact api
  • Loading branch information
keisuke-umezawa authored Nov 23, 2023
2 parents cd2159e + 68e181e commit f3fda8f
Show file tree
Hide file tree
Showing 9 changed files with 511 additions and 53 deletions.
60 changes: 54 additions & 6 deletions optuna_dashboard/artifact/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def proxy_trial_artifact(

@app.post("/api/artifacts/<study_id:int>/<trial_id:int>")
@json_api_view
def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
def upload_trial_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
trial = storage.get_trial(trial_id)
if trial is None:
response.status = 400
Expand Down Expand Up @@ -144,17 +144,50 @@ def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
"artifacts": list_trial_artifacts(storage.get_study_system_attrs(study_id), trial),
}

@app.post("/api/artifacts/<study_id:int>")
@json_api_view
def upload_study_artifact_api(study_id: int) -> dict[str, Any]:
if artifact_store is None:
response.status = 400 # Bad Request
return {"reason": "Cannot access to the artifacts."}
file = request.json.get("file")
if file is None:
response.status = 400
return {"reason": "Please specify the 'file' key."}

_, data = parse_data_uri(file)
filename = request.json.get("filename", "")
artifact_id = str(uuid.uuid4())
artifact_store.write(artifact_id, io.BytesIO(data))

mimetype, encoding = mimetypes.guess_type(filename)
artifact = {
"artifact_id": artifact_id,
"filename": filename,
"mimetype": mimetype or DEFAULT_MIME_TYPE,
"encoding": encoding,
}
attr_key = ARTIFACTS_ATTR_PREFIX + artifact_id
storage.set_study_system_attr(study_id, attr_key, json.dumps(artifact))

response.status = 201

return {
"artifact_id": artifact_id,
"artifacts": list_study_artifacts(storage.get_study_system_attrs(study_id)),
}

@app.delete("/api/artifacts/<study_id:int>/<trial_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
@json_api_view
def delete_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str, Any]:
def delete_trial_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str, Any]:
if artifact_store is None:
response.status = 400 # Bad Request
return {"reason": "Cannot access to the artifacts."}
artifact_store.remove(artifact_id)

# The artifact's metadata is stored in one of the following two locations:
storage.set_study_system_attr(
study_id, _dashboard_trial_artifact_prefix(trial_id) + artifact_id, json.dumps(None)
study_id, _dashboard_artifact_prefix(trial_id) + artifact_id, json.dumps(None)
)
storage.set_trial_system_attr(
trial_id, ARTIFACTS_ATTR_PREFIX + artifact_id, json.dumps(None)
Expand All @@ -163,6 +196,21 @@ def delete_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str,
response.status = 204
return {}

@app.delete("/api/artifacts/<study_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
@json_api_view
def delete_study_artifact(study_id: int, artifact_id: str) -> dict[str, Any]:
if artifact_store is None:
response.status = 400 # Bad Request
return {"reason": "Cannot access to the artifacts."}
artifact_store.remove(artifact_id)

storage.set_study_system_attr(
study_id, ARTIFACTS_ATTR_PREFIX + artifact_id, json.dumps(None)
)

response.status = 204
return {}


def upload_artifact(
backend: ArtifactBackend,
Expand Down Expand Up @@ -220,7 +268,7 @@ def objective(trial: optuna.Trial) -> float:
return artifact_id


def _dashboard_trial_artifact_prefix(trial_id: int) -> str:
def _dashboard_artifact_prefix(trial_id: int) -> str:
return DASHBOARD_ARTIFACTS_ATTR_PREFIX + f"{trial_id}:"


Expand All @@ -240,7 +288,7 @@ def get_trial_artifact_meta(
) -> Optional[ArtifactMeta]:
# Search study_system_attrs due to backward compatibility.
study_system_attrs = storage.get_study_system_attrs(study_id)
attr_key = _dashboard_trial_artifact_prefix(trial_id=trial_id) + artifact_id
attr_key = _dashboard_artifact_prefix(trial_id=trial_id) + artifact_id
artifact_meta = study_system_attrs.get(attr_key)
if artifact_meta is not None:
return json.loads(artifact_meta)
Expand Down Expand Up @@ -284,7 +332,7 @@ def list_trial_artifacts(
dashboard_artifact_metas = [
json.loads(value)
for key, value in study_system_attrs.items()
if key.startswith(_dashboard_trial_artifact_prefix(trial._trial_id))
if key.startswith(_dashboard_artifact_prefix(trial._trial_id))
]

# Collect ArtifactMeta from trial_system_attrs. Note that artifacts uploaded via
Expand Down
84 changes: 74 additions & 10 deletions optuna_dashboard/ts/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import {
tellTrialAPI,
saveTrialUserAttrsAPI,
renameStudyAPI,
uploadArtifactAPI,
uploadTrialArtifactAPI,
uploadStudyArtifactAPI,
getMetaInfoAPI,
deleteArtifactAPI,
deleteTrialArtifactAPI,
deleteStudyArtifactAPI,
reportPreferenceAPI,
skipPreferentialTrialAPI,
removePreferentialHistoryAPI,
Expand Down Expand Up @@ -100,7 +102,13 @@ export const actionCreator = () => {
setTrial(studyId, trialIndex, newTrial)
}

const deleteTrialArtifact = (
const setStudyArtifacts = (studyId: number, artifacts: Artifact[]) => {
const newStudy: StudyDetail = Object.assign({}, studyDetails[studyId])
newStudy.artifacts = artifacts
setStudyDetailState(studyId, newStudy)
}

const deleteTrialArtifactState = (
studyId: number,
trialId: number,
artifact_id: string
Expand All @@ -122,6 +130,18 @@ export const actionCreator = () => {
setTrialArtifacts(studyId, index, newArtifacts)
}

const deleteStudyArtifactState = (studyId: number, artifact_id: string) => {
const artifacts = studyDetails[studyId].artifacts
const artifactIndex = artifacts.findIndex(
(a) => a.artifact_id === artifact_id
)
const newArtifacts = [
...artifacts.slice(0, artifactIndex),
...artifacts.slice(artifactIndex + 1, artifacts.length),
]
setStudyArtifacts(studyId, newArtifacts)
}

const setTrialStateValues = (
studyId: number,
index: number,
Expand Down Expand Up @@ -430,7 +450,7 @@ export const actionCreator = () => {
})
}

const uploadArtifact = (
const uploadTrialArtifact = (
studyId: number,
trialId: number,
file: File
Expand All @@ -439,7 +459,7 @@ export const actionCreator = () => {
setUploading(true)
reader.readAsDataURL(file)
reader.onload = (upload: ProgressEvent<FileReader>) => {
uploadArtifactAPI(
uploadTrialArtifactAPI(
studyId,
trialId,
file.name,
Expand Down Expand Up @@ -467,14 +487,56 @@ export const actionCreator = () => {
}
}

const deleteArtifact = (
const uploadStudyArtifact = (studyId: number, file: File): void => {
const reader = new FileReader()
setUploading(true)
reader.readAsDataURL(file)
reader.onload = (upload: ProgressEvent<FileReader>) => {
uploadStudyArtifactAPI(
studyId,
file.name,
upload.target?.result as string
)
.then((res) => {
setUploading(false)
setStudyArtifacts(studyId, res.artifacts)
})
.catch((err) => {
setUploading(false)
const reason = err.response?.data.reason
enqueueSnackbar(`Failed to upload ${reason}`, { variant: "error" })
})
}
reader.onerror = (error) => {
enqueueSnackbar(`Failed to read the file ${error}`, { variant: "error" })
console.log(error)
}
}

const deleteTrialArtifact = (
studyId: number,
trialId: number,
artifactId: string
): void => {
deleteArtifactAPI(studyId, trialId, artifactId)
deleteTrialArtifactAPI(studyId, trialId, artifactId)
.then(() => {
deleteTrialArtifactState(studyId, trialId, artifactId)
enqueueSnackbar(`Success to delete an artifact.`, {
variant: "success",
})
})
.catch((err) => {
const reason = err.response?.data.reason
enqueueSnackbar(`Failed to delete ${reason}.`, {
variant: "error",
})
})
}

const deleteStudyArtifact = (studyId: number, artifactId: string): void => {
deleteStudyArtifactAPI(studyId, artifactId)
.then(() => {
deleteTrialArtifact(studyId, trialId, artifactId)
deleteStudyArtifactState(studyId, artifactId)
enqueueSnackbar(`Success to delete an artifact.`, {
variant: "success",
})
Expand Down Expand Up @@ -693,8 +755,10 @@ export const actionCreator = () => {
saveReloadInterval,
saveStudyNote,
saveTrialNote,
uploadArtifact,
deleteArtifact,
uploadTrialArtifact,
uploadStudyArtifact,
deleteTrialArtifact,
deleteStudyArtifact,
makeTrialComplete,
makeTrialFail,
saveTrialUserAttrs,
Expand Down
30 changes: 28 additions & 2 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ type UploadArtifactAPIResponse = {
artifacts: Artifact[]
}

export const uploadArtifactAPI = (
export const uploadTrialArtifactAPI = (
studyId: number,
trialId: number,
fileName: string,
Expand All @@ -296,7 +296,22 @@ export const uploadArtifactAPI = (
})
}

export const deleteArtifactAPI = (
export const uploadStudyArtifactAPI = (
studyId: number,
fileName: string,
dataUrl: string
): Promise<UploadArtifactAPIResponse> => {
return axiosInstance
.post<UploadArtifactAPIResponse>(`/api/artifacts/${studyId}`, {
file: dataUrl,
filename: fileName,
})
.then((res) => {
return res.data
})
}

export const deleteTrialArtifactAPI = (
studyId: number,
trialId: number,
artifactId: string
Expand All @@ -308,6 +323,17 @@ export const deleteArtifactAPI = (
})
}

export const deleteStudyArtifactAPI = (
studyId: number,
artifactId: string
): Promise<void> => {
return axiosInstance
.delete<void>(`/api/artifacts/${studyId}/${artifactId}`)
.then(() => {
return
})
}

export const tellTrialAPI = (
trialId: number,
state: TrialStateFinished,
Expand Down
Loading

0 comments on commit f3fda8f

Please sign in to comment.