perf: load AI model with celery function task
OmidSa75 committed Dec 4, 2023
0 parents commit 40de774
Showing 2 changed files with 95 additions and 0 deletions.
76 changes: 76 additions & 0 deletions celery_app.py
@@ -0,0 +1,76 @@
import celery
import redis
import torch
import torchvision
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator

device = torch.device('cpu')


def load_model(_self):

    # Build a MobileNetV2 feature backbone and the anchor generator for the RPN.
    backbone = torchvision.models.mobilenet_v2().features
    backbone.out_channels = 1280
    anchor_generator = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'],
        output_size=7,
        sampling_ratio=2,
    )

    mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'],
        output_size=14,
        sampling_ratio=2,
    )

    model = MaskRCNN(
        backbone,
        num_classes=2,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        mask_roi_pool=mask_roi_pooler,
    )
    model.eval()
    model.to(device)
    return model


class BaseTask(celery.Task):
    def __init__(self) -> None:
        super().__init__()

    def run(self, *args, **kwargs):
        return super().run(*args, **kwargs)


def inference_model(_self):
    # Lazily build the model on first use and cache it on the task instance,
    # so each worker process loads it only once.
    if not hasattr(_self, 'ai_model'):
        _self.ai_model = _self.load_model()
        print('Load AI model')

    input_x = [torch.rand(3, 300, 400).to(device),
               torch.rand(3, 500, 400).to(device)]
    prediction = _self.ai_model(input_x)
    print('Hi, this is an inference function')
    return str(type(prediction))

if __name__ == "__main__":
    celery_app = celery.Celery(
        'main_celery',
        backend='redis://:Eyval@localhost:6379/1',
        broker='redis://:Eyval@localhost:6379/1',
        task_default_queue='AIWithCelery',
    )

    # input_x = [torch.rand(3, 300, 400).to(device),
    #            torch.rand(3, 500, 400).to(device)]
    # prediction = model(input_x)

    # Register inference_model as a bound task. The extra load_model=... option
    # is attached to the generated task class, which is what makes
    # _self.load_model() resolvable inside the task body.
    celery_app.task(name='inference_model', bind=True,
                    load_model=load_model)(inference_model)
    celery_app.start(['worker', '-l', 'INFO'])
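Note on the registration mechanism: extra keyword options passed to Celery's task() decorator (here load_model=load_model) become attributes of the generated task class, which is why _self.load_model() resolves inside the task body. The fragment below is a minimal decorator-style sketch of the same idea; sketch_app and run_inference are hypothetical names introduced only for illustration, and it assumes the load_model builder above is importable from celery_app.py.

import celery

from celery_app import load_model  # the model builder defined above

sketch_app = celery.Celery(
    'sketch',
    backend='redis://:Eyval@localhost:6379/1',
    broker='redis://:Eyval@localhost:6379/1',
)


# load_model=load_model becomes a class attribute of the generated task class,
# so the bound task can call self.load_model() and cache the result on itself,
# loading the model once per worker process, as the functional form above does.
@sketch_app.task(name='inference_model_sketch', bind=True, load_model=load_model)
def run_inference(self):
    if not hasattr(self, 'ai_model'):
        self.ai_model = self.load_model()
    return str(type(self.ai_model))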
19 changes: 19 additions & 0 deletions client_request.py
@@ -0,0 +1,19 @@
import celery
import redis
import numpy as np


if __name__ == "__main__":

    celery_app = celery.Celery(
        'main_celery',
        backend='redis://:Eyval@localhost:6379/1',
        broker='redis://:Eyval@localhost:6379/1',
        task_default_queue='AIWithCelery',
    )
    # Send the task by name to the worker's queue and block until the result arrives.
    task = celery_app.send_task(
        'inference_model',
        queue='AIWithCelery',
    )
    result = task.get()
    print(result)
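Assuming a Redis server is reachable at localhost:6379 with the password used in the URLs above, one plausible way to exercise the pair is to run celery_app.py first (its celery_app.start(['worker', '-l', 'INFO']) call launches the worker) and then run client_request.py. task.get() blocks until the worker returns the stringified prediction type; only the first request handled by a worker process prints 'Load AI model', since later calls reuse the model cached on the task instance.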
