-
Notifications
You must be signed in to change notification settings - Fork 0
/
flash_model_handler.py
43 lines (36 loc) · 1.29 KB
/
flash_model_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import tempfile

from flash.core.trainer import Trainer
from flash.core.utilities.stages import RunningStage
from flash.image import ObjectDetector, ObjectDetectionData
from PIL.Image import Image

from nuclio_detection_labels_output import NuclioDetectionLabelsOutput
class MockTrainer(Trainer):
    """Flash ``Trainer`` whose stage is pinned to prediction at construction.

    The trainer is built with ``Trainer``'s default arguments; immediately
    afterwards ``state.stage`` is set to ``RunningStage.PREDICTING``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably this makes flash components that inspect
        # the trainer stage behave in predict mode even before ``predict()``
        # itself sets the stage — confirm against the flash version in use.
        predicting_stage = RunningStage.PREDICTING
        self.state.stage = predicting_stage  # type: ignore
class FlashModelHandler:
    """Run a Lightning Flash ``ObjectDetector`` on a single PIL image.

    Each :meth:`infer` call writes the image to a temporary JPEG file, wraps
    it in an ``ObjectDetectionData`` datamodule (batch size 1) and runs
    ``Trainer.predict`` with a ``NuclioDetectionLabelsOutput`` formatter.
    """

    def __init__(self, model: ObjectDetector, image_size=1024, labels=None):
        """Store the model and inference settings.

        Args:
            model: Detector to run; it is put into eval mode here.
            image_size: Forwarded as ``transform_kwargs["image_size"]`` when
                the prediction datamodule is built.
            labels: Forwarded as ``labels`` to ``NuclioDetectionLabelsOutput``.
                ``None`` (the default) means a fresh empty dict for this
                instance; a literal ``{}`` default would be one dict object
                shared by every handler.
        """
        self.image_size = image_size
        self.labels = {} if labels is None else labels
        self.model = model
        self.trainer = MockTrainer()
        self.model.eval()

    def infer(self, image: Image, threshold: float = 0.0):
        """Run detection on ``image``.

        Args:
            image: PIL image to analyse. It is also passed unchanged to the
                output formatter as ``image``.
            threshold: Forwarded as ``threshold`` to
                ``NuclioDetectionLabelsOutput``.

        Returns:
            The formatted result for this image (the first entry of the first
            prediction batch), or ``[]`` when the trainer returns ``None`` or
            an empty prediction list.
        """
        # A private temporary directory per call, instead of a fixed
        # "/tmp/image.jpg": concurrent calls can no longer overwrite each
        # other's input file, the file is deleted afterwards, and the code
        # no longer assumes a POSIX /tmp. predict() consumes the datamodule,
        # so it runs inside the ``with`` while the file still exists.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "image.jpg")
            # NOTE(review): PIL cannot write e.g. RGBA or P-mode images as
            # JPEG; callers presumably pass RGB/L images — confirm.
            image.save(path)
            datamodule = ObjectDetectionData.from_files(
                predict_files=[path],
                transform_kwargs={"image_size": self.image_size},
                batch_size=1,
            )
            predictions = self.trainer.predict(
                self.model,
                datamodule=datamodule,
                output=NuclioDetectionLabelsOutput(
                    threshold=threshold,
                    labels=self.labels,
                    image=image,
                ),
            )
        # Covers both ``None`` and an empty list; the latter previously
        # raised IndexError on the indexing below.
        if not predictions:
            return []
        # One file, batch_size=1: the result is assumed to be the first
        # sample of the first batch.
        return predictions[0][0]