Skip to content

Commit

Permalink
fix(api): add tags to endpoints for better doc readibility
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrckd committed Dec 16, 2024
1 parent 176872f commit eb62077
Showing 1 changed file with 41 additions and 11 deletions.
52 changes: 41 additions & 11 deletions api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _needs_output(db_exp):
#


@router.post("/dataset", response_model=schemas.Dataset)
@router.post("/dataset", response_model=schemas.Dataset, tags=["datasets"])
def create_dataset(dataset: schemas.DatasetCreate, db: Session = Depends(get_db)):
try:
db_dataset = crud.create_dataset(db, dataset)
Expand All @@ -40,12 +40,14 @@ def create_dataset(dataset: schemas.DatasetCreate, db: Session = Depends(get_db)
raise e


@router.get("/datasets", response_model=list[schemas.Dataset])
@router.get("/datasets", response_model=list[schemas.Dataset], tags=["datasets"])
def read_datasets(db: Session = Depends(get_db)):
return crud.get_datasets(db)


@router.get("/dataset/{id}", response_model=schemas.Dataset | schemas.DatasetFull)
@router.get(
"/dataset/{id}", response_model=schemas.Dataset | schemas.DatasetFull, tags=["datasets"]
)
def read_dataset(id: int, with_df: bool = False, db: Session = Depends(get_db)):
dataset = crud.get_dataset(db, id)
if dataset is None:
Expand All @@ -57,7 +59,7 @@ def read_dataset(id: int, with_df: bool = False, db: Session = Depends(get_db)):
return schemas.Dataset.from_orm(dataset)


@router.patch("/dataset/{id}", response_model=schemas.Dataset)
@router.patch("/dataset/{id}", response_model=schemas.Dataset, tags=["datasets"])
def patch_dataset(id: int, dataset_patch: schemas.DatasetPatch, db: Session = Depends(get_db)):
db_dataset = crud.update_dataset(db, id, dataset_patch)
if db_dataset is None:
Expand All @@ -71,7 +73,7 @@ def patch_dataset(id: int, dataset_patch: schemas.DatasetPatch, db: Session = De
#


@router.get("/metrics", response_model=list[Metric])
@router.get("/metrics", response_model=list[Metric], tags=["metrics"])
def read_metrics(db: Session = Depends(get_db)):
return crud.get_metrics(db)

Expand All @@ -85,6 +87,7 @@ def read_metrics(db: Session = Depends(get_db)):
"/experiment",
response_model=schemas.Experiment,
description="Launch an experiment. If a model is given, it will be use to generate the model output (answer), otherwise it will use the `output` column of the given dataset.",
tags=["experiments"],
)
def create_experiment(experiment: schemas.ExperimentCreate, db: Session = Depends(get_db)):
try:
Expand All @@ -108,6 +111,7 @@ def create_experiment(experiment: schemas.ExperimentCreate, db: Session = Depend
"/experiment/{id}",
response_model=schemas.Experiment,
description="Update an experiment. The given metrics will be added (or rerun) to the existing results for this experiments. Use rerun_answers if want to re-generate the answers/output.",
tags=["experiments"],
)
def patch_experiment(
id: int, experiment_patch: schemas.ExperimentPatch, db: Session = Depends(get_db)
Expand Down Expand Up @@ -144,7 +148,11 @@ def patch_experiment(


@router.delete("/experiment/{id}")
def delete_experiment(id: int, db: Session = Depends(get_db)):
def delete_experiment(
id: int,
db: Session = Depends(get_db),
tags=["experiments"],
):
if not crud.remove_experiment(db, id):
raise HTTPException(status_code=404, detail="Experiment not found")
return "ok"
Expand All @@ -157,6 +165,7 @@ def delete_experiment(id: int, db: Session = Depends(get_db)):
| schemas.ExperimentWithAnswers
| schemas.ExperimentFull
| schemas.ExperimentFullWithDataset,
tags=["experiments"],
)
def read_experiment(
id: int,
Expand All @@ -181,7 +190,11 @@ def read_experiment(
return schemas.Experiment.from_orm(experiment)


@router.get("/experiments", response_model=list[schemas.ExperimentWithResults])
@router.get(
"/experiments",
response_model=list[schemas.ExperimentWithResults],
tags=["experiments"],
)
def read_experiments(set_id: int | None = None, limit: int = 100, db: Session = Depends(get_db)):
experiments = crud.get_experiments(db, set_id=set_id, limit=limit)

Expand All @@ -196,7 +209,11 @@ def read_experiments(set_id: int | None = None, limit: int = 100, db: Session =
#


@router.post("/experiment_set", response_model=schemas.ExperimentSet)
@router.post(
"/experiment_set",
response_model=schemas.ExperimentSet,
tags=["experiment_set"],
)
def create_experimentset(experimentset: schemas.ExperimentSetCreate, db: Session = Depends(get_db)):
try:
db_expset = crud.create_experimentset(db, experimentset)
Expand All @@ -219,6 +236,7 @@ def create_experimentset(experimentset: schemas.ExperimentSetCreate, db: Session
"/experiment_set/{id}",
response_model=schemas.ExperimentSet,
description="Update an experimentset: New experiments will be added to the runner queue.",
tags=["experiment_set"],
)
def patch_experimentset(
id: int, experimentset_patch: schemas.ExperimentSetPatch, db: Session = Depends(get_db)
Expand Down Expand Up @@ -247,7 +265,11 @@ def patch_experimentset(
return db_expset


@router.get("/experiment_sets", response_model=list[schemas.ExperimentSet])
@router.get(
"/experiment_sets",
response_model=list[schemas.ExperimentSet],
tags=["experiment_set"],
)
def read_experimentsets(db: Session = Depends(get_db)):
experimentsets = crud.get_experimentsets(db)
if experimentsets is None:
Expand All @@ -256,15 +278,22 @@ def read_experimentsets(db: Session = Depends(get_db)):
# return [schemas.ExperimentSet.from_orm(x) for x in experimentsets]


@router.get("/experiment_set/{id}", response_model=schemas.ExperimentSet)
@router.get(
"/experiment_set/{id}",
response_model=schemas.ExperimentSet,
tags=["experiment_set"],
)
def read_experimentset(id: int, db: Session = Depends(get_db)):
experimentset = crud.get_experimentset(db, id)
if experimentset is None:
raise HTTPException(status_code=404, detail="ExperimentSet not found")
return experimentset


@router.delete("/experiment_set/{id}")
@router.delete(
"/experiment_set/{id}",
tags=["experiment_set"],
)
def delete_experimentset(id: int, db: Session = Depends(get_db), admin_check=Depends(admin_only)):
if not crud.remove_experimentset(db, id):
raise HTTPException(status_code=404, detail="ExperimentSet not found")
Expand All @@ -275,6 +304,7 @@ def delete_experimentset(id: int, db: Session = Depends(get_db), admin_check=Dep
"/retry/experiment_set/{id}",
response_model=schemas.RetryRuns,
description="Re-run failed runs.",
tags=["experiment_set"],
)
def retry_runs(id: int, db: Session = Depends(get_db)):
experimentset = crud.get_experimentset(db, id)
Expand Down

0 comments on commit eb62077

Please sign in to comment.