Skip to content

Commit

Permalink
Add example and validator to checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Nov 23, 2023
1 parent d9f9f16 commit 70c45a8
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions API/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
class PredictionRequest(BaseModel):
"""
Request model for the prediction endpoint.
Example :
{
"bbox": [
100.56228021333352,
13.685230854641182,
100.56383321235313,
13.685961853747969
],
"checkpoint": "https://fair-dev.hotosm.org/api/v1/workspace/download/dataset_58/output/training_324//checkpoint.tflite",
"zoom_level": 20,
"source": "https://tiles.openaerialmap.org/6501a65c0906de000167e64d/0/6501a65c0906de000167e64e/{z}/{x}/{y}"
}
"""

bbox: List[float]
Expand Down Expand Up @@ -127,6 +140,25 @@ def validate_zoom_level(cls, value):
raise ValueError("Zoom level should be between 18 and 22")
return value

@validator("checkpoint")
def validate_checkpoint(cls, value):
"""
Validates checkpoint parameter. If URL, download the file to temp directory.
"""
if value.startswith("http"):
response = requests.get(value)
if response.status_code != 200:
raise ValueError(
"Failed to download model checkpoint from the provided URL"
)
_, temp_file_path = tempfile.mkstemp(suffix=".tflite")
with open(temp_file_path, "wb") as f:
f.write(response.content)
return temp_file_path
elif not os.path.exists(value):
raise ValueError("Model checkpoint file not found")
return value


@app.post("/predict/")
async def predict_api(request: PredictionRequest):
Expand Down

0 comments on commit 70c45a8

Please sign in to comment.