diff --git a/API/main.py b/API/main.py index 3d6e8dd..9234a46 100644 --- a/API/main.py +++ b/API/main.py @@ -17,6 +17,19 @@ class PredictionRequest(BaseModel): """ Request model for the prediction endpoint. + + Example : + { + "bbox": [ + 100.56228021333352, + 13.685230854641182, + 100.56383321235313, + 13.685961853747969 + ], + "checkpoint": "https://fair-dev.hotosm.org/api/v1/workspace/download/dataset_58/output/training_324//checkpoint.tflite", + "zoom_level": 20, + "source": "https://tiles.openaerialmap.org/6501a65c0906de000167e64d/0/6501a65c0906de000167e64e/{z}/{x}/{y}" + } """ bbox: List[float] @@ -127,6 +140,25 @@ def validate_zoom_level(cls, value): raise ValueError("Zoom level should be between 18 and 22") return value + @validator("checkpoint") + def validate_checkpoint(cls, value): + """ + Validates checkpoint parameter. If URL, download the file to temp directory. + """ + if value.startswith("http"): + response = requests.get(value) + if response.status_code != 200: + raise ValueError( + "Failed to download model checkpoint from the provided URL" + ) + _, temp_file_path = tempfile.mkstemp(suffix=".tflite") + with open(temp_file_path, "wb") as f: + f.write(response.content) + return temp_file_path + elif not os.path.exists(value): + raise ValueError("Model checkpoint file not found") + return value + @app.post("/predict/") async def predict_api(request: PredictionRequest):