From 8e8bf02f6e22ca42a9368a192b77c5073b939a2b Mon Sep 17 00:00:00 2001
From: OmidSa75
Date: Mon, 4 Dec 2023 13:39:41 +0330
Subject: [PATCH] perf: load model in the base task class

---
 celery_app.py     | 24 +++++++++++++++---------
 client_request.py | 27 +++++++++++++++++++++------
 2 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/celery_app.py b/celery_app.py
index 09fa3b1..eb3ffeb 100644
--- a/celery_app.py
+++ b/celery_app.py
@@ -8,7 +8,7 @@
 device = torch.device('cpu')
 
 
-def load_model(_self):
+def load_model():
     backbone = torchvision.models.mobilenet_v2().features
     backbone.out_channels = 1280
 
@@ -43,19 +43,25 @@ def load_model(_self):
 class BaseTask(celery.Task):
     def __init__(self) -> None:
         super().__init__()
+        self._ai_model = None
+
 
-    def run(self, *args, **kwargs):
-        return super().run(*args, **kwargs)
+    @property
+    def ai_model(self):
+        if self._ai_model is None:
+            self._ai_model = load_model()
+            print("Load AI Model")
+        return self._ai_model
 
 
-def inference_model(_self):
-    if not hasattr(_self, 'ai_model'):
-        _self.ai_model = _self.load_model()
-        print('Load AI model')
+def inference_model(self):
+    # if not hasattr(_self, 'ai_model'):
+    #     _self.ai_model = _self.load_model()
+    #     print('Load AI model')
     input_x = [torch.rand(3, 300, 400).to(
         device), torch.rand(3, 500, 400).to(device)]
-    prediction = _self.ai_model(input_x)
+    prediction = self.ai_model(input_x)
     print('Hi, this is a inference function')
     return str(type(prediction))
 
 
@@ -72,5 +78,5 @@ def inference_model(_self):
 #     prediction = model(input_x)
 
 celery_app.task(name='inference_model', bind=True,
-                load_model=load_model)(inference_model)
+                base=BaseTask)(inference_model)
 celery_app.start(['worker', '-l', 'INFO'])
diff --git a/client_request.py b/client_request.py
index b009acb..05cecc7 100644
--- a/client_request.py
+++ b/client_request.py
@@ -1,6 +1,7 @@
 import celery
 import redis
-import numpy as np
+import numpy as np
+import sys
 
 
 if __name__ == "__main__":
@@ -11,9 +12,23 @@
         broker='redis://:Eyval@localhost:6379/1',
         task_default_queue='AIWithCelery',
     )
-    task = celery_app.send_task(
-        'inference_model',
-        queue='AIWithCelery',
-    )
-    result = task.get()
+    if len(sys.argv) > 1:
+        number_of_calls = sys.argv[1]
+        try:
+            number_of_calls = int(number_of_calls)
+        except ValueError:
+            print('Error:: first arg must be integer ')
+            sys.exit(1)
+    else:
+        number_of_calls = 1
+    g = celery.group(celery_app.signature('inference_model', queue='AIWithCelery') for _ in range(number_of_calls))
+    # for _ in range(number_of_calls):
+
+    #     task = celery_app.send_task(
+    #         'inference_model',
+    #         queue='AIWithCelery',
+    #     )
+    #     result = task.get()
+    task = g.apply_async()
+    task.get()
 
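
The core of this patch is caching the model on a custom celery.Task base class, so each worker process builds it once on first use instead of reconstructing it per call. A minimal standalone sketch of that pattern, with a stub load_model() standing in for the torchvision detector built in celery_app.py:

    import celery

    def load_model():
        # stub for the real torchvision model construction in celery_app.py
        return object()

    class BaseTask(celery.Task):
        def __init__(self) -> None:
            super().__init__()
            self._ai_model = None  # filled in lazily, once per worker process

        @property
        def ai_model(self):
            if self._ai_model is None:
                self._ai_model = load_model()
            return self._ai_model

    app = celery.Celery('demo')

    @app.task(name='inference_model', base=BaseTask, bind=True)
    def inference_model(self):
        # first call triggers load_model(); later calls reuse the cached model
        return str(type(self.ai_model))

Because Celery instantiates the task class once per worker process, the property yields one cached model per process (with the default prefork pool, one copy per child), not one per task invocation.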
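On the client side, client_request.py now fans out N calls with celery.group instead of a send_task loop. A sketch of that dispatch pattern, using placeholder broker/backend URLs and a fixed call count in place of the patch's credentials and sys.argv parsing:

    import celery

    app = celery.Celery(
        'client',
        backend='redis://localhost:6379/0',  # placeholder, not the patch's URLs
        broker='redis://localhost:6379/1',
        task_default_queue='AIWithCelery',
    )

    number_of_calls = 4  # stands in for the sys.argv parsing in the patch
    g = celery.group(
        app.signature('inference_model', queue='AIWithCelery')
        for _ in range(number_of_calls)
    )
    result = g.apply_async()
    print(result.get())  # one return value per task in the group

group(...).apply_async() returns a GroupResult, so a single .get() waits on all N tasks together rather than serializing a send_task/get round trip per call.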