Skip to content

Commit

Permalink
perf: load model in the base task class
Browse files Browse the repository at this point in the history
  • Loading branch information
OmidSa75 committed Dec 4, 2023
1 parent 40de774 commit 8e8bf02
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
24 changes: 15 additions & 9 deletions celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
device = torch.device('cpu')


def load_model(_self):
def load_model():

backbone = torchvision.models.mobilenet_v2().features
backbone.out_channels = 1280
Expand Down Expand Up @@ -43,19 +43,25 @@ def load_model(_self):
class BaseTask(celery.Task):
    """Celery base task that lazily loads the AI model once per worker process."""

    def __init__(self) -> None:
        super().__init__()
        # The model is not created here; it is built on first access through
        # the `ai_model` property so task registration stays cheap.
        self._ai_model = None

    def run(self, *args, **kwargs):
        # Delegate to celery.Task.run; the actual work is the function bound
        # to this task at registration time.
        return super().run(*args, **kwargs)

    @property
    def ai_model(self):
        """Return the model, building it via load_model() on first access."""
        if self._ai_model is None:
            # NOTE(review): not thread-safe — assumes one loader per worker
            # process (celery prefork); confirm if threaded workers are used.
            self._ai_model = load_model()
            print("Load AI Model")
        return self._ai_model


def inference_model(self):
    """Run a demo forward pass through the lazily-loaded model.

    `self` is the bound Celery task instance (registered with bind=True and
    base=BaseTask), so `self.ai_model` triggers the one-time model load on
    first use.  Returns the prediction's type as a string so the task result
    is trivially serializable.
    """
    # Two random "images" of different sizes, moved to the module-level
    # torch device.
    input_x = [torch.rand(3, 300, 400).to(
        device), torch.rand(3, 500, 400).to(device)]
    prediction = self.ai_model(input_x)
    print('Hi, this is a inference function')
    return str(type(prediction))

Expand All @@ -72,5 +78,5 @@ def inference_model(_self):
# prediction = model(input_x)

# Register `inference_model` as a bound task with BaseTask as its base class:
# bind=True passes the task instance as the first argument, and BaseTask's
# `ai_model` property loads the model lazily, once per worker process.
celery_app.task(name='inference_model', bind=True,
                base=BaseTask)(inference_model)
celery_app.start(['worker', '-l', 'INFO'])
27 changes: 21 additions & 6 deletions client_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import celery
import redis
import numpy as np
import numpy as np
import sys


if __name__ == "__main__":
Expand All @@ -11,9 +12,23 @@
broker='redis://:Eyval@localhost:6379/1',
task_default_queue='AIWithCelery',
)
task = celery_app.send_task(
'inference_model',
queue='AIWithCelery',
)
result = task.get()
if len(sys.argv) > 1:
number_of_calls = sys.argv[1]
try:
number_of_calls = int(number_of_calls)
except ValueError:
print('Error:: first arg must be integer <Number_Of_Calls>')
sys.exit(1)
else:
number_of_calls = 1
g = celery.group(celery_app.signature('inference_model', queue='AIWithCelery') for _ in range(number_of_calls))
# for _ in range(number_of_calls):

# task = celery_app.send_task(
# 'inference_model',
# queue='AIWithCelery',
# )
# result = task.get()
task = g.apply_async()
task.get()

0 comments on commit 8e8bf02

Please sign in to comment.