
Commit

Move logic for downloading resources from CS into main background process
Marishka17 committed Mar 5, 2024
1 parent a9333db commit 8b5892f
Showing 4 changed files with 50 additions and 106 deletions.
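In short, this commit drops the separate RQ job that downloaded a resource from cloud storage (and the configure_dependent_job_to_download_from_cs helper that created it): the import job itself is now enqueued with import_resource_from_cloud_storage, which downloads the file and then delegates to the original import callable. Below is a minimal sketch of the new enqueue pattern, using only names visible in this diff; the surrounding view plumbing (queue, request, rq_id, importer, filename, db_storage, key, location, org_id) is assumed to be set up as before and is not shown.

from django.conf import settings

from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage
from cvat.apps.engine.models import Location
from cvat.apps.engine.utils import (
    define_dependent_job, get_rq_job_meta, get_rq_lock_by_user,
    import_resource_with_clean_up_after,
)

# Default case: the file is already on local disk, so run the importer directly.
func = import_resource_with_clean_up_after
func_args = (importer, filename, request.user.id, org_id)

if location == Location.CLOUD_STORAGE:
    # Wrap the call instead of enqueuing a dependent download job: the same
    # RQ job first downloads `key` from `db_storage` into `filename`, then
    # invokes func(importer, filename, ...) itself.
    func_args = (db_storage, key, func) + func_args
    func = import_resource_from_cloud_storage

with get_rq_lock_by_user(queue, request.user.id):
    rq_job = queue.enqueue_call(
        func=func,
        args=func_args,
        job_id=rq_id,
        meta={'tmp_file': filename, **get_rq_job_meta(request=request, db_obj=None)},
        depends_on=define_dependent_job(queue, request.user.id),
        result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
        failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds(),
    )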
32 changes: 11 additions & 21 deletions cvat/apps/engine/backup.py
@@ -36,14 +36,14 @@
LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskReadSerializer,
ProjectReadSerializer, ProjectFileSerializer, TaskFileSerializer, RqIdSerializer)
from cvat.apps.engine.utils import (
av_scan_paths, process_failed_job, configure_dependent_job_to_download_from_cs,
av_scan_paths, process_failed_job,
get_rq_job_meta, get_import_rq_id, import_resource_with_clean_up_after,
sendfile, define_dependent_job, get_rq_lock_by_user, build_backup_file_name,
)
from cvat.apps.engine.models import (
StorageChoice, StorageMethodChoice, DataChoice, Task, Project, Location)
from cvat.apps.engine.task import JobFileMapping, _create_thread
from cvat.apps.engine.cloud_provider import download_file_from_bucket, export_resource_to_cloud_storage
from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage, export_resource_to_cloud_storage
from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.view_utils import get_cloud_storage_for_import_or_export
from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception
@@ -1051,7 +1051,9 @@ def _import(importer, request, queue, rq_id, Serializer, file_field_name, locati

if not rq_job:
org_id = getattr(request.iam_context['organization'], 'id', None)
dependent_job = None

func = import_resource_with_clean_up_after
func_args = (importer, filename, request.user.id, org_id)

location = location_conf.get('location')
if location == Location.LOCAL:
@@ -1084,42 +1086,30 @@ def _import(importer, request, queue, rq_id, Serializer, file_field_name, locati
with NamedTemporaryFile(prefix='cvat_', dir=settings.TMP_FILES_ROOT, delete=False) as tf:
filename = tf.name

dependent_job = configure_dependent_job_to_download_from_cs(
queue=queue,
rq_id=rq_id,
rq_func=download_file_from_bucket,
db_storage=db_storage,
filename=filename,
key=key,
request=request,
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)
func_args = (db_storage, key, func) + func_args
func = import_resource_from_cloud_storage

user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
rq_job = queue.enqueue_call(
func=import_resource_with_clean_up_after,
args=(importer, filename, request.user.id, org_id),
func=func,
args=func_args,
job_id=rq_id,
meta={
'tmp_file': filename,
**get_rq_job_meta(request=request, db_obj=None)
},
depends_on=dependent_job or define_dependent_job(queue, user_id),
depends_on=define_dependent_job(queue, user_id),
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)
else:
if rq_job.is_finished:
if rq_job.dependency:
rq_job.dependency.delete()
project_id = rq_job.return_value()
rq_job.delete()
return Response({'id': project_id}, status=status.HTTP_201_CREATED)
elif rq_job.is_failed or \
rq_job.is_deferred and rq_job.dependency and rq_job.dependency.is_failed:
elif rq_job.is_failed:
exc_info = process_failed_job(rq_job)
# RQ adds a prefix with exception class name
import_error_prefix = '{}.{}'.format(
16 changes: 14 additions & 2 deletions cvat/apps/engine/cloud_provider.py
@@ -10,7 +10,7 @@
from enum import Enum
from io import BytesIO
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Optional, Any, Callable
from typing import Dict, List, Optional, Any, Callable, TypeVar

import boto3
from azure.core.exceptions import HttpResponseError, ResourceExistsError
@@ -963,12 +963,24 @@ def db_storage_to_storage_instance(db_storage):
}
return get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details)

def download_file_from_bucket(db_storage: Any, filename: str, key: str) -> None:
T = TypeVar('T', Callable[[str, int, int], int], Callable[[str, int, str, bool], None])

def import_resource_from_cloud_storage(
db_storage: Any,
key: str,
cleanup_func: Callable[[T, str,], Any],
import_func: T,
filename: str,
*args,
**kwargs,
) -> Any:
storage = db_storage_to_storage_instance(db_storage)

with storage.download_fileobj(key) as data, open(filename, 'wb+') as f:
f.write(data.getbuffer())

return cleanup_func(import_func, filename, *args, **kwargs)

def export_resource_to_cloud_storage(
db_storage: Any,
key: str,
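For reference, when a cloud-storage import is enqueued this way, the composed arguments unwind inside the worker as shown below; this expansion is illustrative only and is not code from the repository.

# The enqueued call
#   func=import_resource_from_cloud_storage,
#   args=(db_storage, key, import_resource_with_clean_up_after,
#         importer, filename, request.user.id, org_id)
# is executed by the worker as:
import_resource_from_cloud_storage(
    db_storage,                           # storage to download from
    key,                                  # object key in the bucket
    import_resource_with_clean_up_after,  # cleanup_func
    importer,                             # import_func (the T callable)
    filename,                             # temporary file the object is written to
    request.user.id, org_id,              # *args forwarded to cleanup_func
)
# ...i.e. it downloads `key` into `filename` and then returns
# import_resource_with_clean_up_after(importer, filename, request.user.id, org_id).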
47 changes: 3 additions & 44 deletions cvat/apps/engine/utils.py
@@ -144,9 +144,7 @@ def parse_exception_message(msg):
return parsed_msg

def process_failed_job(rq_job: Job):
exc_info = str(rq_job.exc_info or getattr(rq_job.dependency, 'exc_info', None) or '')
if rq_job.dependency:
rq_job.dependency.delete()
exc_info = str(rq_job.exc_info or '')
rq_job.delete()

msg = parse_exception_message(exc_info)
@@ -204,50 +202,11 @@ def define_dependent_job(
return Dependency(jobs=[sorted(user_jobs, key=lambda job: job.created_at)[-1]], allow_failure=True) if user_jobs else None


def get_rq_lock_by_user(queue: DjangoRQ, user_id: int, additional_condition: bool = True) -> Union[Lock, nullcontext]:
if settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER and additional_condition:
def get_rq_lock_by_user(queue: DjangoRQ, user_id: int) -> Union[Lock, nullcontext]:
if settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER:
return queue.connection.lock(f'{queue.name}-lock-{user_id}', timeout=30)
return nullcontext()


def configure_dependent_job_to_download_from_cs(
queue: DjangoRQ,
rq_id: str,
rq_func: Callable[[Any, str, str], None],
db_storage: Any,
filename: str,
key: str,
request: HttpRequest,
result_ttl: float,
failure_ttl: float,
should_be_dependent: bool = settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER,
) -> Job:
rq_job_id_download_file = rq_id + f'?action=download_{key.replace("/", ".")}'
rq_job_download_file = queue.fetch_job(rq_job_id_download_file)

if rq_job_download_file and (rq_job_download_file.is_finished or rq_job_download_file.is_failed):
rq_job_download_file.delete()
rq_job_download_file = None

if not rq_job_download_file:
# note: boto3 resource isn't pickleable, so we can't use storage
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id):
rq_job_download_file = queue.enqueue_call(
func=rq_func,
args=(db_storage, filename, key),
job_id=rq_job_id_download_file,
meta={
**get_rq_job_meta(request=request, db_obj=db_storage),
KEY_TO_EXCLUDE_FROM_DEPENDENCY: True,
},
result_ttl=result_ttl,
failure_ttl=failure_ttl,
depends_on=define_dependent_job(queue, user_id, should_be_dependent, rq_id=rq_job_id_download_file)
)
return rq_job_download_file

def get_rq_job_meta(request, db_obj):
# to prevent circular import
from cvat.apps.webhooks.signals import project_id, organization_id
61 changes: 22 additions & 39 deletions cvat/apps/engine/views.py
@@ -42,7 +42,7 @@

import cvat.apps.dataset_manager as dm
import cvat.apps.dataset_manager.views # pylint: disable=unused-import
from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, download_file_from_bucket, export_resource_to_cloud_storage
from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, import_resource_from_cloud_storage, export_resource_to_cloud_storage
from cvat.apps.events.handlers import handle_dataset_export, handle_dataset_import
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer
@@ -70,7 +70,7 @@

from utils.dataset_manifest import ImageManifestManager
from cvat.apps.engine.utils import (
av_scan_paths, process_failed_job, configure_dependent_job_to_download_from_cs,
av_scan_paths, process_failed_job,
parse_exception_message, get_rq_job_meta, get_import_rq_id,
import_resource_with_clean_up_after, sendfile, define_dependent_job, get_rq_lock_by_user,
build_annotations_file_name,
@@ -2836,11 +2836,12 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
# If filename is specified, we consider that the file was uploaded via TUS and already exists in the filesystem,
# so we don't need to create a temporary file.
# Otherwise, filename specifies a key in cloud storage, so we need to download the file.
dependent_job = None
location = location_conf.get('location') if location_conf else Location.LOCAL

db_storage = None

func = import_resource_with_clean_up_after
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly)

if not filename or location == Location.CLOUD_STORAGE:
if location != Location.CLOUD_STORAGE:
serializer = AnnotationFileSerializer(data=request.data)
@@ -2873,27 +2874,18 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
delete=False) as tf:
filename = tf.name

dependent_job = configure_dependent_job_to_download_from_cs(
queue=queue,
rq_id=rq_id,
rq_func=download_file_from_bucket,
db_storage=db_storage,
filename=filename,
key=key,
request=request,
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)
func_args = (db_storage, key, func) + func_args
func = import_resource_from_cloud_storage

av_scan_paths(filename)
user_id = request.user.id

with get_rq_lock_by_user(queue, user_id, additional_condition=not dependent_job):
with get_rq_lock_by_user(queue, user_id):
rq_job = queue.enqueue_call(
func=import_resource_with_clean_up_after,
args=(rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly),
func=func,
args=func_args,
job_id=rq_id,
depends_on=dependent_job or define_dependent_job(queue, user_id, rq_id=rq_id),
depends_on=define_dependent_job(queue, user_id, rq_id=rq_id),
meta={
'tmp_file': filename,
**get_rq_job_meta(request=request, db_obj=db_obj),
@@ -2910,12 +2902,9 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
return Response(serializer.data, status=status.HTTP_202_ACCEPTED)
else:
if rq_job.is_finished:
if rq_job.dependency:
rq_job.dependency.delete()
rq_job.delete()
return Response(status=status.HTTP_201_CREATED)
elif rq_job.is_failed or \
rq_job.is_deferred and rq_job.dependency and rq_job.dependency.is_failed:
elif rq_job.is_failed:
exc_info = process_failed_job(rq_job)

import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:'
@@ -3090,10 +3079,13 @@ def _import_project_dataset(request, rq_id_template, rq_func, db_obj, format_nam
# (e.g. the user closed the browser tab after the job was created,
# but no requests to check its status were ever made)
rq_job.delete()
dependent_job = None

location = location_conf.get('location') if location_conf else None
db_storage = None

func = import_resource_with_clean_up_after
func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly)

if not filename and location != Location.CLOUD_STORAGE:
serializer = DatasetFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
@@ -3125,30 +3117,21 @@ def _import_project_dataset(request, rq_id_template, rq_func, db_obj, format_nam
delete=False) as tf:
filename = tf.name

dependent_job = configure_dependent_job_to_download_from_cs(
queue=queue,
rq_id=rq_id,
rq_func=download_file_from_bucket,
db_storage=db_storage,
filename=filename,
key=key,
request=request,
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)
func_args = (db_storage, key, func) + func_args
func = import_resource_from_cloud_storage

user_id = request.user.id

with get_rq_lock_by_user(queue, user_id, additional_condition=not dependent_job):
with get_rq_lock_by_user(queue, user_id):
rq_job = queue.enqueue_call(
func=import_resource_with_clean_up_after,
args=(rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly),
func=func,
args=func_args,
job_id=rq_id,
meta={
'tmp_file': filename,
**get_rq_job_meta(request=request, db_obj=db_obj),
},
depends_on=dependent_job or define_dependent_job(queue, user_id, rq_id=rq_id),
depends_on=define_dependent_job(queue, user_id, rq_id=rq_id),
result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(),
failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds()
)
