Initialize worker

OmidSa75 released this on 09 Dec at 10:59

One of the main drawbacks of the previous solutions is that the model is loaded on the first request in each worker. To overcome this, we use Celery's `signals` module.

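```python
import celery
from celery import signals


class BaseTask(celery.Task):
    def __init__(self) -> None:
        super().__init__()
        self._ai_model = None  # defined up front so the property never hits AttributeError
        # Register (connect) a handler for worker_init; it runs when the worker starts.
        signals.worker_init.connect(self.on_worker_init)

    def on_worker_init(self, *args, **kwargs):
        print("Loading AI Model ...")
        self._ai_model = load_model()  # load_model() is the project's own loader (not shown)
        print("AI Model Loaded")

    @property
    def ai_model(self):
        # Fallback: load lazily if the signal handler did not run in this process.
        if self._ai_model is None:
            self._ai_model = load_model()
        return self._ai_model
```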
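The key point of the snippet above is that `connect()` only registers the handler; the model is loaded later, when the worker sends the signal at startup. The sketch below shows that order of events in plain Python. `Signal`, `worker_init`, `load_model`, and `load_calls` here are hypothetical toy stand-ins, not Celery's real classes or API.

```python
class Signal:
    """Toy stand-in for a Celery signal (hypothetical, simplified)."""

    def __init__(self) -> None:
        self._receivers = []

    def connect(self, receiver):
        # connect() only stores the callback; nothing is executed here.
        self._receivers.append(receiver)
        return receiver

    def send(self, sender=None, **kwargs):
        # send() is what actually calls every connected receiver.
        return [(r, r(sender=sender, signal=self, **kwargs)) for r in self._receivers]


worker_init = Signal()  # hypothetical stand-in for celery.signals.worker_init
load_calls = 0


def load_model():
    """Hypothetical loader standing in for the project's real load_model()."""
    global load_calls
    load_calls += 1
    return object()


class BaseTask:
    """Plain-Python version of the signal-based BaseTask (no Celery needed)."""

    def __init__(self) -> None:
        self._ai_model = None
        worker_init.connect(self.on_worker_init)

    def on_worker_init(self, *args, **kwargs):
        self._ai_model = load_model()


task = BaseTask()
print(load_calls)  # 0: connecting the handler loads nothing
worker_init.send(sender="worker")  # simulates the worker starting up
print(load_calls)  # 1: the handler ran once at "startup"
```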

However, this method only works with models that are loaded on the CPU. If you move the model to a CUDA device, the code fails with the following error:

Cannot re-initialize CUDA in forked subprocess.

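This happens because `worker_init` fires in the main worker process, before the prefork pool forks its child processes, so CUDA is already initialized when the children inherit the model. The pool processes therefore have to be started with the `spawn` method instead. Celery doesn't support `spawn` yet, so the following snippet forces it through billiard, Celery's fork of `multiprocessing`: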

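```python
import os

# Run this before the worker starts its pool (e.g. where the Celery app is created).
os.environ["FORKED_BY_MULTIPROCESSING"] = "1"
if os.name != "nt":
    # Windows already uses spawn; elsewhere, force billiard (Celery's
    # multiprocessing fork) to spawn via its private _force_start_method.
    from billiard import context
    context._force_start_method("spawn")
    print('Context is changed to SPAWN')
```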
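`_force_start_method` is a private billiard helper. For comparison, the standard-library `multiprocessing` package (which billiard was forked from) exposes the same idea publicly through `set_start_method(..., force=True)`, or through a per-use context from `get_context("spawn")`. This is a minimal sketch using the stdlib only, not billiard; the helper name `force_spawn` is hypothetical.

```python
import multiprocessing


def force_spawn() -> str:
    # force=True replaces a start method that was already chosen,
    # which is what the billiard call above does for Celery's pool.
    multiprocessing.set_start_method("spawn", force=True)
    return multiprocessing.get_start_method()


# Alternative that leaves the global default untouched:
# a context object whose Process/Pool always use spawn.
spawn_ctx = multiprocessing.get_context("spawn")

if __name__ == "__main__":
    print(force_spawn())                  # spawn
    print(spawn_ctx.get_start_method())   # spawn
```

The context-based form avoids changing global state, but Celery's worker creates its pool through billiard's default context, which is why the release forces the default instead.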

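However, the signal-based preloading doesn't work with `spawn`: each spawned pool process starts a fresh interpreter, so a model loaded in the main worker process is not inherited. The task therefore loads the model lazily on first access: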

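```python
import celery


class BaseTask(celery.Task):
    def __init__(self) -> None:
        super().__init__()
        self._ai_model = None  # loaded lazily, once per worker process

    @property
    def ai_model(self):
        # Load the model on first access in this process, then reuse it.
        if self._ai_model is None:
            self._ai_model = load_model()  # project's own loader (not shown)
        return self._ai_model
```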
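The trade-off of the lazy property is easy to see without Celery. In this plain-Python sketch, `LazyModelTask`, `load_model`, and `load_calls` are hypothetical names, and creating a second instance stands in for a freshly spawned pool process that re-imports the task module:

```python
load_calls = 0


def load_model():
    """Hypothetical stand-in for the project's real (expensive) loader."""
    global load_calls
    load_calls += 1
    return object()


class LazyModelTask:
    """Plain-Python sketch of the lazy BaseTask (no Celery needed)."""

    def __init__(self) -> None:
        self._ai_model = None

    @property
    def ai_model(self):
        # Load on first access, then reuse the cached object.
        if self._ai_model is None:
            self._ai_model = load_model()
        return self._ai_model


task_in_process_a = LazyModelTask()
first = task_in_process_a.ai_model   # first request: loads the model
second = task_in_process_a.ai_model  # later requests: cached
print(load_calls, first is second)   # 1 True

# A spawned process builds its own task instance, so it pays the
# loading cost again on its own first request.
task_in_process_b = LazyModelTask()
task_in_process_b.ai_model
print(load_calls)  # 2
```

So with `spawn`, each pool process loads the model once, on its first request, which is the cost the signal-based approach was meant to avoid.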

The other problem the `spawn` method creates is that tasks declared with the `celery_app.task` decorator are not registered in the worker, so the module that defines them has to be listed in `include` in the Celery setup:

```python
import celery

celery_app = celery.Celery(
    'main_celery',
    backend='redis://:Eyval@localhost:6379/1',
    broker='redis://:Eyval@localhost:6379/1',
    task_default_queue='AIWithCelery',
    # Modules the worker imports at startup so their tasks get registered.
    include=['celery_app']
)
```
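Why `include` helps: a task decorator registers a task only as a side effect of its module being executed, and Celery's `include` is a list of modules the worker imports when it starts. The toy sketch below mimics that with a hand-rolled registry; `toy_registry`, `task`, `load_includes`, `toy_tasks`, and `predict` are all hypothetical names, not Celery's API. The task module is written to a temporary directory only to keep the example self-contained.

```python
import importlib
import sys
import tempfile
import textwrap
import types
from pathlib import Path

# Toy task registry standing in for Celery's (hypothetical, not Celery's API).
registry = types.ModuleType("toy_registry")
registry.TASKS = {}


def task(func):
    # Registration happens only when the decorated module is executed.
    registry.TASKS[func.__name__] = func
    return func


registry.task = task
sys.modules["toy_registry"] = registry


def load_includes(names):
    """Import each listed module so its decorated tasks register themselves."""
    importlib.invalidate_caches()
    for name in names:
        importlib.import_module(name)


# Write a hypothetical task module to a temporary directory.
tmp = Path(tempfile.mkdtemp())
(tmp / "toy_tasks.py").write_text(textwrap.dedent("""
    from toy_registry import task

    @task
    def predict(x):
        return x * 2
"""))
sys.path.insert(0, str(tmp))

print(sorted(registry.TASKS))  # []: nothing registered before the import
load_includes(["toy_tasks"])
print(sorted(registry.TASKS))  # ['predict']
```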