diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 10dead22404d..b70012a2ddbd 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -36,14 +36,14 @@ LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskReadSerializer, ProjectReadSerializer, ProjectFileSerializer, TaskFileSerializer, RqIdSerializer) from cvat.apps.engine.utils import ( - av_scan_paths, process_failed_job, configure_dependent_job_to_download_from_cs, + av_scan_paths, process_failed_job, get_rq_job_meta, get_import_rq_id, import_resource_with_clean_up_after, sendfile, define_dependent_job, get_rq_lock_by_user, build_backup_file_name, ) from cvat.apps.engine.models import ( StorageChoice, StorageMethodChoice, DataChoice, Task, Project, Location) from cvat.apps.engine.task import JobFileMapping, _create_thread -from cvat.apps.engine.cloud_provider import download_file_from_bucket, export_resource_to_cloud_storage +from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage, export_resource_to_cloud_storage from cvat.apps.engine.location import StorageType, get_location_configuration from cvat.apps.engine.view_utils import get_cloud_storage_for_import_or_export from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception @@ -1051,7 +1051,9 @@ def _import(importer, request, queue, rq_id, Serializer, file_field_name, locati if not rq_job: org_id = getattr(request.iam_context['organization'], 'id', None) - dependent_job = None + + func = import_resource_with_clean_up_after + func_args = (importer, filename, request.user.id, org_id) location = location_conf.get('location') if location == Location.LOCAL: @@ -1084,42 +1086,30 @@ def _import(importer, request, queue, rq_id, Serializer, file_field_name, locati with NamedTemporaryFile(prefix='cvat_', dir=settings.TMP_FILES_ROOT, delete=False) as tf: filename = tf.name - dependent_job = configure_dependent_job_to_download_from_cs( - queue=queue, - rq_id=rq_id, - rq_func=download_file_from_bucket, - db_storage=db_storage, - filename=filename, - key=key, - request=request, - result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), - failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() - ) + func_args = (db_storage, key, func) + func_args + func = import_resource_from_cloud_storage user_id = request.user.id with get_rq_lock_by_user(queue, user_id): rq_job = queue.enqueue_call( - func=import_resource_with_clean_up_after, - args=(importer, filename, request.user.id, org_id), + func=func, + args=func_args, job_id=rq_id, meta={ 'tmp_file': filename, **get_rq_job_meta(request=request, db_obj=None) }, - depends_on=dependent_job or define_dependent_job(queue, user_id), + depends_on=define_dependent_job(queue, user_id), result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() ) else: if rq_job.is_finished: - if rq_job.dependency: - rq_job.dependency.delete() project_id = rq_job.return_value() rq_job.delete() return Response({'id': project_id}, status=status.HTTP_201_CREATED) - elif rq_job.is_failed or \ - rq_job.is_deferred and rq_job.dependency and rq_job.dependency.is_failed: + elif rq_job.is_failed: exc_info = process_failed_job(rq_job) # RQ adds a prefix with exception class name import_error_prefix = '{}.{}'.format( diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py index 1ddf88c95bdb..63cfefcd6309 100644 --- a/cvat/apps/engine/cloud_provider.py +++ b/cvat/apps/engine/cloud_provider.py @@ -10,7 +10,7 @@ from enum import Enum from io import BytesIO from multiprocessing.pool import ThreadPool -from typing import Dict, List, Optional, Any, Callable +from typing import Dict, List, Optional, Any, Callable, TypeVar import boto3 from azure.core.exceptions import HttpResponseError, ResourceExistsError @@ -963,12 +963,24 @@ def db_storage_to_storage_instance(db_storage): } return get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details) -def download_file_from_bucket(db_storage: Any, filename: str, key: str) -> None: +T = TypeVar('T', Callable[[str, int, int], int], Callable[[str, int, str, bool], None]) + +def import_resource_from_cloud_storage( + db_storage: Any, + key: str, + cleanup_func: Callable[[T, str,], Any], + import_func: T, + filename: str, + *args, + **kwargs, +) -> Any: storage = db_storage_to_storage_instance(db_storage) with storage.download_fileobj(key) as data, open(filename, 'wb+') as f: f.write(data.getbuffer()) + return cleanup_func(import_func, filename, *args, **kwargs) + def export_resource_to_cloud_storage( db_storage: Any, key: str, diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index ee11a65bec94..efb6f0c8ba1c 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -144,9 +144,7 @@ def parse_exception_message(msg): return parsed_msg def process_failed_job(rq_job: Job): - exc_info = str(rq_job.exc_info or getattr(rq_job.dependency, 'exc_info', None) or '') - if rq_job.dependency: - rq_job.dependency.delete() + exc_info = str(rq_job.exc_info or '') rq_job.delete() msg = parse_exception_message(exc_info) @@ -204,50 +202,11 @@ def define_dependent_job( return Dependency(jobs=[sorted(user_jobs, key=lambda job: job.created_at)[-1]], allow_failure=True) if user_jobs else None -def get_rq_lock_by_user(queue: DjangoRQ, user_id: int, additional_condition: bool = True) -> Union[Lock, nullcontext]: - if settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER and additional_condition: +def get_rq_lock_by_user(queue: DjangoRQ, user_id: int) -> Union[Lock, nullcontext]: + if settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER: return queue.connection.lock(f'{queue.name}-lock-{user_id}', timeout=30) return nullcontext() - -def configure_dependent_job_to_download_from_cs( - queue: DjangoRQ, - rq_id: str, - rq_func: Callable[[Any, str, str], None], - db_storage: Any, - filename: str, - key: str, - request: HttpRequest, - result_ttl: float, - failure_ttl: float, - should_be_dependent: bool = settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER, -) -> Job: - rq_job_id_download_file = rq_id + f'?action=download_{key.replace("/", ".")}' - rq_job_download_file = queue.fetch_job(rq_job_id_download_file) - - if rq_job_download_file and (rq_job_download_file.is_finished or rq_job_download_file.is_failed): - rq_job_download_file.delete() - rq_job_download_file = None - - if not rq_job_download_file: - # note: boto3 resource isn't pickleable, so we can't use storage - user_id = request.user.id - - with get_rq_lock_by_user(queue, user_id): - rq_job_download_file = queue.enqueue_call( - func=rq_func, - args=(db_storage, filename, key), - job_id=rq_job_id_download_file, - meta={ - **get_rq_job_meta(request=request, db_obj=db_storage), - KEY_TO_EXCLUDE_FROM_DEPENDENCY: True, - }, - result_ttl=result_ttl, - failure_ttl=failure_ttl, - depends_on=define_dependent_job(queue, user_id, should_be_dependent, rq_id=rq_job_id_download_file) - ) - return rq_job_download_file - def get_rq_job_meta(request, db_obj): # to prevent circular import from cvat.apps.webhooks.signals import project_id, organization_id diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index f428e3516119..e86dc5b89ad3 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -42,7 +42,7 @@ import cvat.apps.dataset_manager as dm import cvat.apps.dataset_manager.views # pylint: disable=unused-import -from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, download_file_from_bucket, export_resource_to_cloud_storage +from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, import_resource_from_cloud_storage, export_resource_to_cloud_storage from cvat.apps.events.handlers import handle_dataset_export, handle_dataset_import from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer @@ -70,7 +70,7 @@ from utils.dataset_manifest import ImageManifestManager from cvat.apps.engine.utils import ( - av_scan_paths, process_failed_job, configure_dependent_job_to_download_from_cs, + av_scan_paths, process_failed_job, parse_exception_message, get_rq_job_meta, get_import_rq_id, import_resource_with_clean_up_after, sendfile, define_dependent_job, get_rq_lock_by_user, build_annotations_file_name, @@ -2836,11 +2836,12 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name, # If filename is specified we consider that file was uploaded via TUS, so it exists in filesystem # Then we dont need to create temporary file # Or filename specify key in cloud storage so we need to download file - dependent_job = None location = location_conf.get('location') if location_conf else Location.LOCAL - db_storage = None + func = import_resource_with_clean_up_after + func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly) + if not filename or location == Location.CLOUD_STORAGE: if location != Location.CLOUD_STORAGE: serializer = AnnotationFileSerializer(data=request.data) @@ -2873,27 +2874,18 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name, delete=False) as tf: filename = tf.name - dependent_job = configure_dependent_job_to_download_from_cs( - queue=queue, - rq_id=rq_id, - rq_func=download_file_from_bucket, - db_storage=db_storage, - filename=filename, - key=key, - request=request, - result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), - failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() - ) + func_args = (db_storage, key, func) + func_args + func = import_resource_from_cloud_storage av_scan_paths(filename) user_id = request.user.id - with get_rq_lock_by_user(queue, user_id, additional_condition=not dependent_job): + with get_rq_lock_by_user(queue, user_id): rq_job = queue.enqueue_call( - func=import_resource_with_clean_up_after, - args=(rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly), + func=func, + args=func_args, job_id=rq_id, - depends_on=dependent_job or define_dependent_job(queue, user_id, rq_id=rq_id), + depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), meta={ 'tmp_file': filename, **get_rq_job_meta(request=request, db_obj=db_obj), @@ -2910,12 +2902,9 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name, return Response(serializer.data, status=status.HTTP_202_ACCEPTED) else: if rq_job.is_finished: - if rq_job.dependency: - rq_job.dependency.delete() rq_job.delete() return Response(status=status.HTTP_201_CREATED) - elif rq_job.is_failed or \ - rq_job.is_deferred and rq_job.dependency and rq_job.dependency.is_failed: + elif rq_job.is_failed: exc_info = process_failed_job(rq_job) import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:' @@ -3090,10 +3079,13 @@ def _import_project_dataset(request, rq_id_template, rq_func, db_obj, format_nam # (e.g the user closed the browser tab when job has been created # but no one requests for checking status were not made) rq_job.delete() - dependent_job = None + location = location_conf.get('location') if location_conf else None db_storage = None + func = import_resource_with_clean_up_after + func_args = (rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly) + if not filename and location != Location.CLOUD_STORAGE: serializer = DatasetFileSerializer(data=request.data) if serializer.is_valid(raise_exception=True): @@ -3125,30 +3117,21 @@ def _import_project_dataset(request, rq_id_template, rq_func, db_obj, format_nam delete=False) as tf: filename = tf.name - dependent_job = configure_dependent_job_to_download_from_cs( - queue=queue, - rq_id=rq_id, - rq_func=download_file_from_bucket, - db_storage=db_storage, - filename=filename, - key=key, - request=request, - result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), - failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() - ) + func_args = (db_storage, key, func) + func_args + func = import_resource_from_cloud_storage user_id = request.user.id - with get_rq_lock_by_user(queue, user_id, additional_condition=not dependent_job): + with get_rq_lock_by_user(queue, user_id): rq_job = queue.enqueue_call( - func=import_resource_with_clean_up_after, - args=(rq_func, filename, db_obj.pk, format_name, conv_mask_to_poly), + func=func, + args=func_args, job_id=rq_id, meta={ 'tmp_file': filename, **get_rq_job_meta(request=request, db_obj=db_obj), }, - depends_on=dependent_job or define_dependent_job(queue, user_id, rq_id=rq_id), + depends_on=define_dependent_job(queue, user_id, rq_id=rq_id), result_ttl=settings.IMPORT_CACHE_SUCCESS_TTL.total_seconds(), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds() )