Skip to content

Commit

Permalink
Merge pull request #214 from CogStack/bg-process-preds
Browse files Browse the repository at this point in the history
 CU-8695x1dy9: Changes for running predictions async in a bg process
  • Loading branch information
tomolopolis authored Nov 29, 2024
2 parents 7c49911 + 92ac80f commit 3a9337f
Show file tree
Hide file tree
Showing 19 changed files with 367 additions and 139 deletions.
20 changes: 20 additions & 0 deletions docker-compose-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ services:
context: ./webapp
args:
SPACY_MODELS: ${SPACY_MODELS:-en_core_web_md}
image: medcattrainer-api
restart: always
volumes:
- ./webapp/api/core:/home/api/core
Expand All @@ -23,6 +24,25 @@ services:
- MCT_VERSION=latest
command: /home/scripts/run.sh

# Background task runner: a second container started from the same image tag
# as the API service, running the bg-process script instead of the web server.
medcattrainer-bg-process:
  image: medcattrainer-api
  command: /home/scripts/run-bg-process.sh
  restart: always
  # start after the API service
  depends_on:
    - medcattrainer
  env_file:
    - ./envs/env
  volumes:
    # live-mounted sources and runner script for development
    - ./webapp/api/core:/home/api/core
    - ./webapp/api/api:/home/api/api
    - ./webapp/scripts/run-bg-process.sh:/home/scripts/run-bg-process.sh
    - ./configs:/home/configs
    # named volumes for media, static files and the database
    - api-media:/home/api/media
    - api-static:/home/api/static
    - api-db:/home/api/db


nginx:
image: nginx
restart: always
Expand Down
14 changes: 14 additions & 0 deletions docker-compose-prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ services:
- MCT_VERSION=v2.17.4
command: /home/scripts/run.sh

# Background task runner: runs the bg-process script in its own container.
medcattrainer-bg-process:
  # Keep this tag identical to the other medcat-trainer services in this file
  # (v2.17.4) so the runner and the API server execute the same code.
  image: cogstacksystems/medcat-trainer:v2.17.4
  restart: always
  volumes:
    - ./configs:/home/configs
    - api-media:/home/api/media
    - api-static:/home/api/static
    - api-db:/home/api/db
    - api-db-backup:/home/api/db-backup
  env_file:
    - ./envs/env
  command: /home/scripts/run-bg-process.sh

# crontab - for db backup
medcattrainer-db-backup:
image: cogstacksystems/medcat-trainer:v2.17.4
Expand Down
15 changes: 15 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# projects are not used.

services:
# api server
medcattrainer:
image: cogstacksystems/medcat-trainer:v2.17.4
restart: always
Expand All @@ -17,6 +18,20 @@ services:
- MCT_VERSION=v2.17.4
command: /home/scripts/run.sh

# Background task runner: same image as the api server, but started with the
# bg-process script instead of run.sh.
medcattrainer-bg-process:
  image: cogstacksystems/medcat-trainer:v2.17.4
  command: /home/scripts/run-bg-process.sh
  restart: always
  env_file:
    - ./envs/env
  volumes:
    - ./configs:/home/configs
    # named volumes for media, static files, the database and its backups
    - api-media:/home/api/media
    - api-static:/home/api/static
    - api-db:/home/api/db
    - api-db-backup:/home/api/db-backup

# crontab - for db backup
medcattrainer-db-backup:
image: cogstacksystems/medcat-trainer:v2.17.4
Expand Down
2 changes: 2 additions & 0 deletions envs/env
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ OPENBLAS_NUM_THREADS=1

### MedCAT cfg ###
MEDCAT_CONFIG_FILE=/home/configs/base.txt
# maximum number of MedCAT models kept loaded in the model cache (and so available to bg processes) at any one time
MAX_MEDCAT_MODELS=2

### Deployment Realm ###
ENV=non-prod
Expand Down
2 changes: 2 additions & 0 deletions envs/env-prod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ OPENBLAS_NUM_THREADS=1

### MedCAT cfg ###
MEDCAT_CONFIG_FILE=/home/configs/base.txt
# maximum number of MedCAT models kept loaded in the model cache (and so available to bg processes) at any one time
MAX_MEDCAT_MODELS=2
ENV=prod

# SECRET KEY - edit this for prod deployments,
Expand Down
1 change: 1 addition & 0 deletions webapp/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ RUN for SPACY_MODEL in ${SPACY_MODELS}; do python -m spacy download ${SPACY_MODE

WORKDIR /home/api/
RUN chmod a+x /home/scripts/run.sh
RUN chmod a+x /home/scripts/run-bg-process.sh
17 changes: 9 additions & 8 deletions webapp/api/api/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict

from django.contrib.auth.models import User
from django.db import transaction
from django.db.models import Q

from core.settings import MEDIA_ROOT
Expand Down Expand Up @@ -42,14 +43,14 @@ def dataset_from_file(dataset: Dataset):
"The 'name' column are document IDs, and the 'text' column is the text you're "
"collecting annotations for")


for i, row in enumerate(df.iterrows()):
row = row[1]
document = Document()
document.name = row['name']
document.text = sanitise_input(row['text'])
document.dataset = dataset
document.save()
with transaction.atomic():
for i, row in enumerate(df.iterrows()):
row = row[1]
document = Document()
document.name = row['name']
document.text = sanitise_input(row['text'])
document.dataset = dataset
document.save()


def sanitise_input(text: str):
Expand Down
14 changes: 14 additions & 0 deletions webapp/api/api/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,22 @@
VOCAB_MAP = {}
CAT_MAP = {}

# Upper bound on the number of entries kept in each model cache map.
# os.getenv returns a str whenever the variable is set, so coerce to int;
# otherwise the comparison against len() below could never succeed.
_MAX_MODELS_LOADED = int(os.getenv("MAX_MEDCAT_MODELS", "1"))

logger = logging.getLogger(__name__)


def _clear_models(cdb_map: Dict[str, CDB]=CDB_MAP,
                  vocab_map: Dict[str, Vocab]=VOCAB_MAP,
                  cat_map: Dict[str, CAT]=CAT_MAP):
    """
    Trim each cache map down to at most ``_MAX_MODELS_LOADED`` entries.

    Intended to be called right after a new model has been inserted. Entries are
    evicted oldest-first (dicts preserve insertion order), so the model that was
    just added is only dropped when the configured limit is zero or negative.

    :param cdb_map: cache of loaded CDB objects, mutated in place
    :param vocab_map: cache of loaded Vocab objects, mutated in place
    :param cat_map: cache of loaded CAT objects, mutated in place
    """
    # A non-positive limit means "cache nothing"; clamping at 0 keeps the loop
    # from trying to pop from an already empty map.
    limit = max(_MAX_MODELS_LOADED, 0)
    for model_map in (cat_map, cdb_map, vocab_map):
        while len(model_map) > limit:
            oldest_key = next(iter(model_map))
            model_map.pop(oldest_key)


def get_medcat_from_cdb_vocab(project,
cdb_map: Dict[str, CDB]=CDB_MAP,
vocab_map: Dict[str, Vocab]=VOCAB_MAP,
Expand Down Expand Up @@ -61,6 +73,7 @@ def get_medcat_from_cdb_vocab(project,
vocab_map[vocab_id] = vocab
cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
cat_map[cat_id] = cat
_clear_models(cat_map=cat_map, cdb_map=cdb_map, vocab_map=vocab_map)
return cat


Expand All @@ -70,6 +83,7 @@ def get_medcat_from_model_pack(project, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
logger.info('Loading model pack from:%s', model_pack_obj.model_pack.path)
cat = CAT.load_model_pack(model_pack_obj.model_pack.path)
cat_map[cat_id] = cat
_clear_models(cat_map=cat_map)
return cat


Expand Down
23 changes: 13 additions & 10 deletions webapp/api/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from background_task import background
from django.contrib.auth.models import User
from django.db import transaction
from django.db.models.signals import post_save
from django.dispatch import receiver
from medcat.cat import CAT
Expand Down Expand Up @@ -241,26 +242,28 @@ def prep_docs(project_id: List[int], doc_ids: List[int], user_id: int):
project = ProjectAnnotateEntities.objects.get(id=project_id)
docs = Document.objects.filter(id__in=doc_ids)

logger.info('Loading CAT object in bg process')
logger.info('Loading CAT object in bg process for project: %s', project.id)
cat = get_medcat(project=project)

# Set CAT filters
cat.config.linking['filters']['cuis'] = project.cuis

for doc in docs:
logger.info(f'Running MedCAT model over doc: {doc.id}')
logger.info(f'Running MedCAT model for project {project.id}:{project.name} over doc: {doc.id}')
spacy_doc = cat(doc.text)
anns = AnnotatedEntity.objects.filter(document=doc).filter(project=project)

add_annotations(spacy_doc=spacy_doc,
user=user,
project=project,
document=doc,
cat=cat,
existing_annotations=anns)
# add doc to prepared_documents
with transaction.atomic():
add_annotations(spacy_doc=spacy_doc,
user=user,
project=project,
document=doc,
cat=cat,
existing_annotations=anns)
# add doc to prepared_documents
project.prepared_documents.add(doc)
project.save()
logger.info('Prepared all docs for project: %s, docs processed: %s',
project.id, project.prepared_documents)


@receiver(post_save, sender=ProjectAnnotateEntities)
Expand Down
105 changes: 71 additions & 34 deletions webapp/api/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from background_task.models import Task, CompletedTask
from django.contrib.auth.views import PasswordResetView
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from django.http import HttpResponseBadRequest, HttpResponseServerError, HttpResponse
from django.shortcuts import render
from django.utils import timezone
Expand Down Expand Up @@ -247,12 +249,6 @@ def prepare_documents(request):
'description': 'Missing CUI filter file, %s, cannot be found on the filesystem, '
'but is still set on the project. To fix remove and reset the '
'cui filter file' % project.cuis_file}, status=500)

if request.data.get('bg_task'):
# execute model infer in bg
job = prep_docs(p_id, d_ids, user.id)
return Response({'bg_job_id': job.id})

try:
for d_id in d_ids:
document = Document.objects.get(id=d_id)
Expand All @@ -268,26 +264,28 @@ def prepare_documents(request):

is_validated = document in project.validated_documents.all()

# If the document is not already annotated, annotate it
if (len(anns) == 0 and not is_validated) or update:
# Based on the project id get the right medcat
cat = get_medcat(project=project)
logger.info('loaded medcat model for project: %s', project.id)

# Set CAT filters
cat.config.linking['filters']['cuis'] = cuis

spacy_doc = cat(document.text)
add_annotations(spacy_doc=spacy_doc,
user=user,
project=project,
document=document,
cat=cat,
existing_annotations=anns)

# add doc to prepared_documents
project.prepared_documents.add(document)
project.save()
with transaction.atomic():
# If the document is not already annotated, annotate it
if (len(anns) == 0 and not is_validated) or update:
# Based on the project id get the right medcat
cat = get_medcat(project=project)
logger.info('loaded medcat model for project: %s', project.id)

# Set CAT filters
cat.config.linking['filters']['cuis'] = cuis

spacy_doc = cat(document.text)

add_annotations(spacy_doc=spacy_doc,
user=user,
project=project,
document=document,
cat=cat,
existing_annotations=anns)

# add doc to prepared_documents
project.prepared_documents.add(document)
project.save()

except Exception as e:
stack = traceback.format_exc()
Expand All @@ -297,24 +295,59 @@ def prepare_documents(request):
return Response({'message': 'Documents prepared successfully'})


@api_view(http_method_names=['POST'])
def prepare_documents_bg(request):
    """
    Queue a background job that runs the project's model over its dataset.

    Expects ``project_id`` in the request body. Responds with the id of the
    queued job as ``bg_job_id``, or with a 400 response when ``project_id`` is
    missing or does not match an existing project.
    """
    user = request.user
    # Get project id
    p_id = request.data.get('project_id')
    if p_id is None:
        return HttpResponseBadRequest('Missing required parameter: project_id')
    try:
        project = ProjectAnnotateEntities.objects.get(id=p_id)
    except ObjectDoesNotExist:
        return HttpResponseBadRequest('No Project found for ID: %s' % p_id)
    docs = Document.objects.filter(dataset=project.dataset)

    # Fetch the validated document ids once instead of re-querying per document.
    validated_ids = set(project.validated_documents.values_list('id', flat=True))

    # Select docs that have no AnnotatedEntities for this project, plus validated docs.
    # NOTE(review): the synchronous prepare_documents view skips validated documents
    # (``len(anns) == 0 and not is_validated``); confirm that including them here is intended.
    d_ids = [d.id for d in docs
             if d.id in validated_ids
             or not AnnotatedEntity.objects.filter(document=d).filter(project=project).exists()]

    # execute model infer in bg
    job = prep_docs(p_id, d_ids, user.id)
    return Response({'bg_job_id': job.id})


@api_view(http_method_names=['GET'])
def prepare_docs_bg_tasks(request):
proj_id = int(request.GET['project'])
def prepare_docs_bg_tasks(_):
running_doc_prep_tasks = Task.objects.filter(queue='doc_prep')
completed_doc_prep_tasks = CompletedTask.objects.filter(queue='doc_prep')

def transform_task_params(task_params_str):
task_params = json.loads(task_params_str)[0]
return {
'document': task_params[1][0],
'project': task_params[0],
'user_id': task_params[2]
}
running_tasks = [transform_task_params(task.task_params) for task in running_doc_prep_tasks
if json.loads(task.task_params)[0][0] == proj_id]
complete_tasks = [transform_task_params(task.task_params) for task in completed_doc_prep_tasks
if json.loads(task.task_params)[0][0] == proj_id]
running_tasks = [transform_task_params(task.task_params) for task in running_doc_prep_tasks]
complete_tasks = [transform_task_params(task.task_params) for task in completed_doc_prep_tasks]
return Response({'running_tasks': running_tasks, 'comp_tasks': complete_tasks})


@api_view(http_method_names=['GET', 'DELETE'])
def prepare_docs_bg_task(request, proj_id):
    """
    Report on (GET) or cancel (DELETE) background document preparation for a project.

    GET responds with the project's dataset size (``dataset_len``) and the number of
    documents already prepared (``prepd_docs_len``), or a 400 response when the
    project does not exist.

    DELETE removes every task in the 'doc_prep' queue whose first positional
    argument is this project id, or responds with 400 when there is none.
    """
    if request.method == 'GET':
        # state of bg running process as determined by prepared docs
        try:
            proj = ProjectAnnotateEntities.objects.get(id=proj_id)
            prepd_docs_count = proj.prepared_documents.count()
            # reuse the project fetched above rather than querying it a second time
            ds_total_count = Document.objects.filter(dataset=proj.dataset.id).count()
            return Response({'proj_id': proj_id, 'dataset_len': ds_total_count, 'prepd_docs_len': prepd_docs_count})
        except ObjectDoesNotExist:
            # Format the message explicitly: a second positional argument to an
            # HttpResponse is its content_type, not a format argument.
            return HttpResponseBadRequest('No Project found for ID: %s' % proj_id)
    else:
        # Compare ids as strings so the lookup works whether the URL parameter
        # arrives as an int or a str.
        task_ids = [task.id for task in Task.objects.filter(queue='doc_prep')
                    if str(json.loads(task.task_params)[0][0]) == str(proj_id)]
        if task_ids:
            Task.objects.filter(id__in=task_ids).delete()
            return Response("Successfully stopped running response")
        else:
            return HttpResponseBadRequest('Could not find running BG Process to stop')

@api_view(http_method_names=['POST'])
def add_annotation(request):
# Get project id
Expand Down Expand Up @@ -623,7 +656,11 @@ def version(_):
def concept_search_index_available(request):
    """
    Check whether concept search collections are available for the given CDBs.

    Reads a comma separated list of CDB ids from the ``cdbs`` query parameter
    (empty items are dropped) and returns the result of ``collections_available``.
    Any exception raised by that call is turned into a 500 response whose body
    includes the error text.
    """
    cdb_ids = request.GET.get('cdbs', '').split(',')
    cdb_ids = [c for c in cdb_ids if len(c)]
    try:
        return collections_available(cdb_ids)
    except Exception as e:
        # Interpolate the error into the message: a second positional argument to
        # an HttpResponse is its content_type, not a format argument.
        return HttpResponseServerError("Solr Search Service not available check the service is up, running "
                                       "and configured correctly. %s" % e)


@api_view(http_method_names=['GET'])
Expand Down
11 changes: 8 additions & 3 deletions webapp/api/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

environ_origins = os.environ.get('CSRF_TRUSTED_ORIGINS', None)
trusted_origins = [] if environ_origins is None else environ_origins.split(',')
CSRF_TRUSTED_ORIGINS = ['https://127.0.0.1:8001', 'http://localhost:8001'] + trusted_origins
environ_origins = os.environ.get('CSRF_TRUSTED_ORIGINS', '')
trusted_origins = [origin.strip() for origin in environ_origins.split(',') if origin.strip()]

CSRF_TRUSTED_ORIGINS = ['http://127.0.0.1:8001', 'http://localhost:8001'] + trusted_origins

SECURE_CROSS_ORIGIN_OPENER_POLICY = None

Expand Down Expand Up @@ -129,6 +130,10 @@
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db/db.sqlite3'),
'OPTIONS': {
'timeout': 20,
'transaction_mode': 'IMMEDIATE'
}
}
}

Expand Down
Loading

0 comments on commit 3a9337f

Please sign in to comment.