diff --git a/CHANGELOG.md b/CHANGELOG.md index 44d4149dbbcd..0c2b180d8692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Support of context images for 2D image tasks () +- Support of cloud storage without copying data into CVAT: server part () - Filter `is_active` for user list () - Ability to export/import tasks () diff --git a/cvat/apps/authentication/auth.py b/cvat/apps/authentication/auth.py index 5e19efb7609c..5c1f8ea3c81c 100644 --- a/cvat/apps/authentication/auth.py +++ b/cvat/apps/authentication/auth.py @@ -159,6 +159,10 @@ def is_comment_author(db_user, db_comment): has_rights = (db_comment.author == db_user) return has_rights +@rules.predicate +def is_cloud_storage_owner(db_user, db_storage): + return db_storage.owner == db_user + # AUTH PERMISSIONS RULES rules.add_perm('engine.role.user', has_user_role) rules.add_perm('engine.role.admin', has_admin_role) @@ -190,6 +194,9 @@ def is_comment_author(db_user, db_comment): rules.add_perm('engine.comment.change', has_admin_role | is_comment_author) +rules.add_perm('engine.cloudstorage.create', has_admin_role | has_user_role) +rules.add_perm('engine.cloudstorage.change', has_admin_role | is_cloud_storage_owner) + class AdminRolePermission(BasePermission): # pylint: disable=no-self-use def has_permission(self, request, view): @@ -329,3 +336,21 @@ class CommentChangePermission(BasePermission): def has_object_permission(self, request, view, obj): return request.user.has_perm('engine.comment.change', obj) +class CloudStorageAccessPermission(BasePermission): + # pylint: disable=no-self-use + def has_object_permission(self, request, view, obj): + return request.user.has_perm("engine.cloudstorage.change", obj) + +class CloudStorageChangePermission(BasePermission): + # pylint: disable=no-self-use + def has_object_permission(self, request, view, obj): + return request.user.has_perm("engine.cloudstorage.change", 
obj) + +class CloudStorageGetQuerySetMixin(object): + def get_queryset(self): + queryset = super().get_queryset() + user = self.request.user + if has_admin_role(user) or self.detail: + return queryset + else: + return queryset.filter(owner=user) diff --git a/cvat/apps/engine/admin.py b/cvat/apps/engine/admin.py index ddacf69ab027..0dab80a8d0de 100644 --- a/cvat/apps/engine/admin.py +++ b/cvat/apps/engine/admin.py @@ -4,14 +4,14 @@ # SPDX-License-Identifier: MIT from django.contrib import admin -from .models import Task, Segment, Job, Label, AttributeSpec, Project +from .models import Task, Segment, Job, Label, AttributeSpec, Project, CloudStorage class JobInline(admin.TabularInline): model = Job can_delete = False # Don't show extra lines to add an object - def has_add_permission(self, request, object=None): + def has_add_permission(self, request, obj): return False class SegmentInline(admin.TabularInline): @@ -21,7 +21,7 @@ class SegmentInline(admin.TabularInline): can_delete = False # Don't show extra lines to add an object - def has_add_permission(self, request, object=None): + def has_add_permission(self, request, obj): return False @@ -84,8 +84,20 @@ class TaskAdmin(admin.ModelAdmin): def has_add_permission(self, request): return False +class CloudStorageAdmin(admin.ModelAdmin): + date_hierarchy = 'updated_date' + readonly_fields = ('created_date', 'updated_date', 'provider_type') + list_display = ('__str__', 'resource', 'owner', 'created_date', 'updated_date') + search_fields = ('provider_type', 'display_name', 'resource', 'owner__username', 'owner__first_name', + 'owner__last_name', 'owner__email',) + + empty_value_display = 'unknown' + + def has_add_permission(self, request): + return False admin.site.register(Task, TaskAdmin) admin.site.register(Segment, SegmentAdmin) admin.site.register(Label, LabelAdmin) admin.site.register(Project, ProjectAdmin) +admin.site.register(CloudStorage, CloudStorageAdmin) diff --git a/cvat/apps/engine/cache.py 
b/cvat/apps/engine/cache.py index 077e6ef14fe9..d57861bcb4e4 100644 --- a/cvat/apps/engine/cache.py +++ b/cvat/apps/engine/cache.py @@ -7,13 +7,16 @@ from diskcache import Cache from django.conf import settings +from tempfile import NamedTemporaryFile +from cvat.apps.engine.log import slogger from cvat.apps.engine.media_extractors import (Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, ZipChunkWriter, ZipCompressedChunkWriter, ImageDatasetManifestReader, VideoDatasetManifestReader) from cvat.apps.engine.models import DataChoice, StorageChoice from cvat.apps.engine.models import DimensionType - +from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials +from cvat.apps.engine.utils import md5_hash class CacheInteraction: def __init__(self, dimension=DimensionType.DIM_2D): self._cache = Cache(settings.CACHE_ROOT) @@ -49,10 +52,12 @@ def prepare_chunk_buff(self, db_data, quality, chunk_number): buff = BytesIO() upload_dir = { StorageChoice.LOCAL: db_data.get_upload_dirname(), - StorageChoice.SHARE: settings.SHARE_ROOT + StorageChoice.SHARE: settings.SHARE_ROOT, + StorageChoice.CLOUD_STORAGE: db_data.get_upload_dirname(), }[db_data.storage] if hasattr(db_data, 'video'): source_path = os.path.join(upload_dir, db_data.video.path) + reader = VideoDatasetManifestReader(manifest_path=db_data.get_manifest_path(), source_path=source_path, chunk_number=chunk_number, chunk_size=db_data.chunk_size, start=db_data.start_frame, @@ -64,12 +69,43 @@ def prepare_chunk_buff(self, db_data, quality, chunk_number): chunk_number=chunk_number, chunk_size=db_data.chunk_size, start=db_data.start_frame, stop=db_data.stop_frame, step=db_data.get_frame_step()) - for item in reader: - source_path = os.path.join(upload_dir, f"{item['name']}{item['extension']}") - images.append((source_path, source_path, None)) - + if db_data.storage == StorageChoice.CLOUD_STORAGE: + db_cloud_storage = db_data.cloud_storage + credentials = Credentials() + credentials.convert_from_db({ + 
'type': db_cloud_storage.credentials_type, + 'value': db_cloud_storage.credentials, + }) + details = { + 'resource': db_cloud_storage.resource, + 'credentials': credentials, + 'specific_attributes': db_cloud_storage.get_specific_attributes() + } + cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details) + cloud_storage_instance.initialize_content() + for item in reader: + name = f"{item['name']}{item['extension']}" + if name not in cloud_storage_instance: + raise Exception('{} file was not found on a {} storage'.format(name, cloud_storage_instance.name)) + with NamedTemporaryFile(mode='w+b', prefix='cvat', suffix=name, delete=False) as temp_file: + source_path = temp_file.name + buf = cloud_storage_instance.download_fileobj(name) + temp_file.write(buf.getvalue()) + if not (checksum := item.get('checksum', None)): + slogger.glob.warning('A manifest file does not contain checksum for image {}'.format(item.get('name'))) + if checksum and not md5_hash(source_path) == checksum: + slogger.glob.warning('Hash sums of files {} do not match'.format(name)) + images.append((source_path, source_path, None)) + else: + for item in reader: + source_path = os.path.join(upload_dir, f"{item['name']}{item['extension']}") + images.append((source_path, source_path, None)) writer.save_as_chunk(images, buff) buff.seek(0) + if db_data.storage == StorageChoice.CLOUD_STORAGE: + images = [image_path for image in images if os.path.exists((image_path := image[0]))] + for image_path in images: + os.remove(image_path) return buff, mime_type def save_chunk(self, db_data_id, chunk_number, quality, buff, mime_type): diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py new file mode 100644 index 000000000000..017d5f7db9e0 --- /dev/null +++ b/cvat/apps/engine/cloud_provider.py @@ -0,0 +1,296 @@ +#from dataclasses import dataclass +from abc import ABC, abstractmethod, abstractproperty +from io import BytesIO + +import 
boto3 +from boto3.s3.transfer import TransferConfig +from botocore.exceptions import WaiterError +from botocore.handlers import disable_signing + +from azure.storage.blob import BlobServiceClient +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import PublicAccess + +from cvat.apps.engine.log import slogger +from cvat.apps.engine.models import CredentialsTypeChoice, CloudProviderChoice + +class _CloudStorage(ABC): + + def __init__(self): + self._files = [] + + @abstractproperty + def name(self): + pass + + @abstractmethod + def create(self): + pass + + @abstractmethod + def exists(self): + pass + + @abstractmethod + def initialize_content(self): + pass + + @abstractmethod + def download_fileobj(self, key): + pass + + def download_file(self, key, path): + file_obj = self.download_fileobj(key) + if isinstance(file_obj, BytesIO): + with open(path, 'wb') as f: + f.write(file_obj.getvalue()) + else: + raise NotImplementedError("Unsupported type {} was found".format(type(file_obj))) + + @abstractmethod + def upload_file(self, file_obj, file_name): + pass + + def __contains__(self, file_name): + return file_name in (item['name'] for item in self._files) + + def __len__(self): + return len(self._files) + + @property + def content(self): + return list(map(lambda x: x['name'] , self._files)) + +def get_cloud_storage_instance(cloud_provider, resource, credentials, specific_attributes=None): + instance = None + if cloud_provider == CloudProviderChoice.AWS_S3: + instance = AWS_S3( + bucket=resource, + access_key_id=credentials.key, + secret_key=credentials.secret_key, + session_token=credentials.session_token, + region=(specific_attributes or {}).get('region', 'us-east-2') # specific_attributes defaults to None + ) + elif cloud_provider == CloudProviderChoice.AZURE_CONTAINER: + instance = AzureBlobContainer( + container=resource, + account_name=credentials.account_name, + sas_token=credentials.session_token + ) + else: + raise NotImplementedError() + return instance + +class
AWS_S3(_CloudStorage): + waiter_config = { + 'Delay': 5, # The amount of time in seconds to wait between attempts. Default: 5 + 'MaxAttempts': 3, # The maximum number of attempts to be made. Default: 20 + } + transfer_config = { + 'max_io_queue': 10, + } + def __init__(self, + bucket, + region, + access_key_id=None, + secret_key=None, + session_token=None): + super().__init__() + if all([access_key_id, secret_key, session_token]): + self._s3 = boto3.resource( + 's3', + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_key, + aws_session_token=session_token, + region_name=region + ) + elif any([access_key_id, secret_key, session_token]): + raise Exception('Insufficient data for authorization') + # anonymous access + if not any([access_key_id, secret_key, session_token]): + self._s3 = boto3.resource('s3', region_name=region) + self._s3.meta.client.meta.events.register('choose-signer.s3.*', disable_signing) + self._client_s3 = self._s3.meta.client + self._bucket = self._s3.Bucket(bucket) + self.region = region + + @property + def bucket(self): + return self._bucket + + @property + def name(self): + return self._bucket.name + + def exists(self): + waiter = self._client_s3.get_waiter('bucket_exists') + try: + waiter.wait( + Bucket=self.name, + WaiterConfig=self.waiter_config + ) + except WaiterError: + raise Exception('A resource {} unavailable'.format(self.name)) + + def is_object_exist(self, key_object): + waiter = self._client_s3.get_waiter('object_exists') + try: + waiter.wait( + Bucket=self.name, # the waiter needs the bucket name string, not the Bucket resource object + Key=key_object, + WaiterConfig=self.waiter_config + ) + except WaiterError: + raise Exception('A file {} unavailable'.format(key_object)) + + def upload_file(self, file_obj, file_name): + self._bucket.upload_fileobj( + Fileobj=file_obj, + Key=file_name, + Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue']) + ) + + def initialize_content(self): + files = self._bucket.objects.all() + self._files = [{ + 'name': item.key, + } for item in
files] + + def download_fileobj(self, key): + buf = BytesIO() + self.bucket.download_fileobj( + Key=key, + Fileobj=buf, + Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue']) + ) + buf.seek(0) + return buf + + def create(self): + try: + responce = self._bucket.create( + ACL='private', + CreateBucketConfiguration={ + 'LocationConstraint': self.region, + }, + ObjectLockEnabledForBucket=False + ) + slogger.glob.info( + 'Bucket {} has been created on {} region'.format( + self.name, + responce['Location'] + )) + except Exception as ex: + msg = str(ex) + slogger.glob.info(msg) + raise Exception(msg) + +class AzureBlobContainer(_CloudStorage): + MAX_CONCURRENCY = 3 + def __init__(self, container, account_name, sas_token=None): + super().__init__() + self._account_name = account_name + if sas_token: + self._blob_service_client = BlobServiceClient(account_url=self.account_url, credential=sas_token) + else: + self._blob_service_client = BlobServiceClient(account_url=self.account_url) + self._container_client = self._blob_service_client.get_container_client(container) + + @property + def container(self): + return self._container_client + + @property + def name(self): + return self._container_client.container_name + + @property + def account_url(self): + return "{}.blob.core.windows.net".format(self._account_name) + + def create(self): + try: + self._container_client.create_container( + metadata={ + 'type' : 'created by CVAT', + }, + public_access=PublicAccess.OFF + ) + except ResourceExistsError: + msg = f"{self._container_client.container_name} already exists" + slogger.glob.info(msg) + raise Exception(msg) + + def exists(self): + return self._container_client.exists(timeout=5) + + def is_object_exist(self, file_name): + blob_client = self._container_client.get_blob_client(file_name) + return blob_client.exists() + + def upload_file(self, file_obj, file_name): + self._container_client.upload_blob(name=file_name, data=file_obj) + + + # TODO: + # def 
multipart_upload(self, file_obj): + # pass + + def initialize_content(self): + files = self._container_client.list_blobs() + self._files = [{ + 'name': item.name + } for item in files] + + def download_fileobj(self, key): + buf = BytesIO() + storage_stream_downloader = self._container_client.download_blob( + blob=key, + offset=None, + length=None, + ) + storage_stream_downloader.download_to_stream(buf, max_concurrency=self.MAX_CONCURRENCY) + buf.seek(0) + return buf + +class GOOGLE_DRIVE(_CloudStorage): + pass + +class Credentials: + __slots__ = ('key', 'secret_key', 'session_token', 'account_name', 'credentials_type') + + def __init__(self, **credentials): + self.key = credentials.get('key', '') + self.secret_key = credentials.get('secret_key', '') + self.session_token = credentials.get('session_token', '') + self.account_name = credentials.get('account_name', '') + self.credentials_type = credentials.get('credentials_type', None) + + def convert_to_db(self): + converted_credentials = { + CredentialsTypeChoice.TEMP_KEY_SECRET_KEY_TOKEN_SET : \ + " ".join([self.key, self.secret_key, self.session_token]), + CredentialsTypeChoice.ACCOUNT_NAME_TOKEN_PAIR : " ".join([self.account_name, self.session_token]), + CredentialsTypeChoice.ANONYMOUS_ACCESS: "", + } + return converted_credentials[self.credentials_type] + + def convert_from_db(self, credentials): + self.credentials_type = credentials.get('type') + if self.credentials_type == CredentialsTypeChoice.TEMP_KEY_SECRET_KEY_TOKEN_SET: + self.key, self.secret_key, self.session_token = credentials.get('value').split() + elif self.credentials_type == CredentialsTypeChoice.ACCOUNT_NAME_TOKEN_PAIR: + self.account_name, self.session_token = credentials.get('value').split() + else: + self.account_name, self.session_token, self.key, self.secret_key = ('', '', '', '') + self.credentials_type = None + + def mapping_with_new_values(self, credentials): + self.credentials_type = credentials.get('credentials_type', 
self.credentials_type) + self.key = credentials.get('key', self.key) + self.secret_key = credentials.get('secret_key', self.secret_key) + self.session_token = credentials.get('session_token', self.session_token) + self.account_name = credentials.get('account_name', self.account_name) + + def values(self): + return [self.key, self.secret_key, self.session_token, self.account_name] diff --git a/cvat/apps/engine/log.py b/cvat/apps/engine/log.py index 98d5c8e2e48b..dfa7dc99349b 100644 --- a/cvat/apps/engine/log.py +++ b/cvat/apps/engine/log.py @@ -5,7 +5,7 @@ import logging import sys from cvat.settings.base import LOGGING -from .models import Job, Task, Project +from .models import Job, Task, Project, CloudStorage def _get_project(pid): try: @@ -25,6 +25,12 @@ def _get_job(jid): except Exception: raise Exception('{} key must be a job identifier'.format(jid)) +def _get_storage(storage_id): + try: + return CloudStorage.objects.get(pk=storage_id) + except Exception: + raise Exception('{} key must be a cloud storage identifier'.format(storage_id)) + def get_logger(logger_name, log_file): logger = logging.getLogger(name=logger_name) logger.setLevel(logging.INFO) @@ -91,6 +97,27 @@ def _get_task_logger(self, jid): job = _get_job(jid) return slogger.task[job.segment.task.id] +class CloudSourceLoggerStorage: + def __init__(self): + self._storage = dict() + + def __getitem__(self, sid): + """Get ceratain storage object for some cloud storage.""" + if sid not in self._storage: + self._storage[sid] = self._create_cloud_storage_logger(sid) + return self._storage[sid] + + def _create_cloud_storage_logger(self, sid): + cloud_storage = _get_storage(sid) + + logger = logging.getLogger('cvat.server.cloud_storage_{}'.format(sid)) + server_file = logging.FileHandler(filename=cloud_storage.get_log_path()) + formatter = logging.Formatter(LOGGING['formatters']['standard']['format']) + server_file.setFormatter(formatter) + logger.addHandler(server_file) + + return logger + class 
ProjectClientLoggerStorage: def __init__(self): self._storage = dict() @@ -156,5 +183,6 @@ class dotdict(dict): 'project': ProjectLoggerStorage(), 'task': TaskLoggerStorage(), 'job': JobLoggerStorage(), + 'cloud_storage': CloudSourceLoggerStorage(), 'glob': logging.getLogger('cvat.server'), }) diff --git a/cvat/apps/engine/migrations/0040_cloud_storage.py b/cvat/apps/engine/migrations/0040_cloud_storage.py new file mode 100644 index 000000000000..c73609fd9fef --- /dev/null +++ b/cvat/apps/engine/migrations/0040_cloud_storage.py @@ -0,0 +1,47 @@ +# Generated by Django 3.1.8 on 2021-05-07 06:42 + +import cvat.apps.engine.models +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('engine', '0039_auto_training'), + ] + + operations = [ + migrations.AlterField( + model_name='data', + name='storage', + field=models.CharField(choices=[('cloud_storage', 'CLOUD_STORAGE'), ('local', 'LOCAL'), ('share', 'SHARE')], default=cvat.apps.engine.models.StorageChoice['LOCAL'], max_length=15), + ), + migrations.CreateModel( + name='CloudStorage', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('provider_type', models.CharField(choices=[('AWS_S3_BUCKET', 'AWS_S3'), ('AZURE_CONTAINER', 'AZURE_CONTAINER'), ('GOOGLE_DRIVE', 'GOOGLE_DRIVE')], max_length=20)), + ('resource', models.CharField(max_length=63)), + ('display_name', models.CharField(max_length=63)), + ('created_date', models.DateTimeField(auto_now_add=True)), + ('updated_date', models.DateTimeField(auto_now=True)), + ('credentials', models.CharField(max_length=500)), + ('credentials_type', models.CharField(choices=[('TEMP_KEY_SECRET_KEY_TOKEN_SET', 'TEMP_KEY_SECRET_KEY_TOKEN_SET'), ('ACCOUNT_NAME_TOKEN_PAIR', 'ACCOUNT_NAME_TOKEN_PAIR'), ('ANONYMOUS_ACCESS', 'ANONYMOUS_ACCESS')], 
max_length=29)), + ('specific_attributes', models.CharField(blank=True, max_length=50)), + ('description', models.TextField(blank=True)), + ('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='cloud_storages', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'default_permissions': (), + 'unique_together': {('provider_type', 'resource', 'credentials')}, + }, + ), + migrations.AddField( + model_name='data', + name='cloud_storage', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='data', to='engine.cloudstorage'), + ), + ] diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index bcc467386fc9..f88f748aa689 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -68,7 +68,7 @@ def __str__(self): return self.value class StorageChoice(str, Enum): - #AWS_S3 = 'aws_s3_bucket' + CLOUD_STORAGE = 'cloud_storage' LOCAL = 'local' SHARE = 'share' @@ -92,6 +92,7 @@ class Data(models.Model): default=DataChoice.IMAGESET) storage_method = models.CharField(max_length=15, choices=StorageMethodChoice.choices(), default=StorageMethodChoice.FILE_SYSTEM) storage = models.CharField(max_length=15, choices=StorageChoice.choices(), default=StorageChoice.LOCAL) + cloud_storage = models.ForeignKey('CloudStorage', on_delete=models.SET_NULL, null=True, related_name='data') class Meta: default_permissions = () @@ -535,3 +536,80 @@ class Comment(models.Model): message = models.TextField(default='') created_date = models.DateTimeField(auto_now_add=True) updated_date = models.DateTimeField(auto_now=True) + +class CloudProviderChoice(str, Enum): + AWS_S3 = 'AWS_S3_BUCKET' + AZURE_CONTAINER = 'AZURE_CONTAINER' + GOOGLE_DRIVE = 'GOOGLE_DRIVE' + + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + @classmethod + def list(cls): + return list(map(lambda x: x.value, cls)) + + def __str__(self): + return self.value + +class 
CredentialsTypeChoice(str, Enum): + # ignore bandit issues because false positives + TEMP_KEY_SECRET_KEY_TOKEN_SET = 'TEMP_KEY_SECRET_KEY_TOKEN_SET' # nosec + ACCOUNT_NAME_TOKEN_PAIR = 'ACCOUNT_NAME_TOKEN_PAIR' # nosec + ANONYMOUS_ACCESS = 'ANONYMOUS_ACCESS' + + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + @classmethod + def list(cls): + return list(map(lambda x: x.value, cls)) + + def __str__(self): + return self.value + +class CloudStorage(models.Model): + # restrictions: + # AWS bucket name, Azure container name - 63 + # AWS access key id - 20 + # AWS secret access key - 40 + # AWS temporary session tocken - None + # The size of the security token that AWS STS API operations return is not fixed. + # We strongly recommend that you make no assumptions about the maximum size. + # The typical token size is less than 4096 bytes, but that can vary. + provider_type = models.CharField(max_length=20, choices=CloudProviderChoice.choices()) + resource = models.CharField(max_length=63) + display_name = models.CharField(max_length=63) + owner = models.ForeignKey(User, null=True, blank=True, + on_delete=models.SET_NULL, related_name="cloud_storages") + created_date = models.DateTimeField(auto_now_add=True) + updated_date = models.DateTimeField(auto_now=True) + credentials = models.CharField(max_length=500) + credentials_type = models.CharField(max_length=29, choices=CredentialsTypeChoice.choices())#auth_type + specific_attributes = models.CharField(max_length=50, blank=True) + description = models.TextField(blank=True) + + class Meta: + default_permissions = () + unique_together = (('provider_type', 'resource', 'credentials'),) + + def __str__(self): + return "{} {} {}".format(self.provider_type, self.display_name, self.id) + + def get_storage_dirname(self): + return os.path.join(settings.CLOUD_STORAGE_ROOT, str(self.id)) + + def get_storage_logs_dirname(self): + return os.path.join(self.get_storage_dirname(), 'logs') + + def 
get_log_path(self): + return os.path.join(self.get_storage_dirname(), "storage.log") + + def get_specific_attributes(self): + attributes = [item for item in self.specific_attributes.split('&') if item] # a blank field splits to [''], which has no '=' part + return { + item.split('=')[0].strip(): item.split('=')[1].strip() + for item in attributes + } if len(attributes) else dict() \ No newline at end of file diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index aaabe18a3ae5..519148f7c9d0 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -9,12 +9,11 @@ from rest_framework import serializers, exceptions from django.contrib.auth.models import User, Group - from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.engine import models +from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials from cvat.apps.engine.log import slogger - class BasicUserSerializer(serializers.ModelSerializer): def validate(self, data): if hasattr(self, 'initial_data'): @@ -273,12 +272,13 @@ class DataSerializer(serializers.ModelSerializer): remote_files = RemoteFileSerializer(many=True, default=[]) use_cache = serializers.BooleanField(default=False) copy_data = serializers.BooleanField(default=False) + cloud_storage_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) class Meta: model = models.Data fields = ('chunk_size', 'size', 'image_quality', 'start_frame', 'stop_frame', 'frame_filter', 'compressed_chunk_type', 'original_chunk_type', 'client_files', 'server_files', 'remote_files', 'use_zip_chunks', - 'use_cache', 'copy_data', 'storage_method', 'storage') + 'cloud_storage_id', 'use_cache', 'copy_data', 'storage_method', 'storage') # pylint: disable=no-self-use def validate_frame_filter(self, value): @@ -771,9 +771,96 @@ def create(self, validated_data): return db_review +class BaseCloudStorageSerializer(serializers.ModelSerializer): + class Meta: + model = models.CloudStorage + exclude = ['credentials'] + +class
CloudStorageSerializer(serializers.ModelSerializer): + owner = BasicUserSerializer(required=False) + session_token = serializers.CharField(max_length=440, allow_blank=True, required=False) + key = serializers.CharField(max_length=20, allow_blank=True, required=False) + secret_key = serializers.CharField(max_length=40, allow_blank=True, required=False) + account_name = serializers.CharField(max_length=24, allow_blank=True, required=False) + + class Meta: + model = models.CloudStorage + fields = ( + 'provider_type', 'resource', 'display_name', 'owner', 'credentials_type', + 'created_date', 'updated_date', 'session_token', 'account_name', 'key', + 'secret_key', 'specific_attributes', 'description' + ) + read_only_fields = ('created_date', 'updated_date', 'owner') + + # pylint: disable=no-self-use + def validate_specific_attributes(self, value): + if value: + attributes = value.split('&') + for attribute in attributes: + if not len(attribute.split('=')) == 2: + raise serializers.ValidationError('Invalid specific attributes') + return value + + def validate(self, attrs): + if attrs.get('provider_type') == models.CloudProviderChoice.AZURE_CONTAINER: + if not attrs.get('account_name', ''): + raise serializers.ValidationError('Account name for Azure container was not specified') + return attrs + + def create(self, validated_data): + provider_type = validated_data.get('provider_type') + should_be_created = validated_data.pop('should_be_created', None) + credentials = Credentials( + account_name=validated_data.pop('account_name', ''), + key=validated_data.pop('key', ''), + secret_key=validated_data.pop('secret_key', ''), + session_token=validated_data.pop('session_token', ''), + credentials_type = validated_data.get('credentials_type') + ) + if should_be_created: + details = { + 'resource': validated_data.get('resource'), + 'credentials': credentials, + 'specific_attributes': { + item.split('=')[0].strip(): item.split('=')[1].strip() + for item in 
validated_data.get('specific_attributes').split('&') + } if len(validated_data.get('specific_attributes', '')) + else dict() + } + storage = get_cloud_storage_instance(cloud_provider=provider_type, **details) + try: + storage.create() + except Exception as ex: + slogger.glob.warning("Failed with creating storage\n{}".format(str(ex))) + raise + + db_storage = models.CloudStorage.objects.create( + credentials=credentials.convert_to_db(), + **validated_data + ) + db_storage.save() + return db_storage + + # pylint: disable=no-self-use + def update(self, instance, validated_data): + credentials = Credentials() + credentials.convert_from_db({ + 'type': instance.credentials_type, + 'value': instance.credentials, + }) + tmp = {k:v for k,v in validated_data.items() if k in {'key', 'secret_key', 'account_name', 'session_token', 'credentials_type'}} + credentials.mapping_with_new_values(tmp) + instance.credentials = credentials.convert_to_db() + instance.credentials_type = validated_data.get('credentials_type', instance.credentials_type) + instance.resource = validated_data.get('resource', instance.resource) + instance.display_name = validated_data.get('display_name', instance.display_name) + + instance.save() + return instance + class RelatedFileSerializer(serializers.ModelSerializer): class Meta: model = models.RelatedFile fields = '__all__' - read_only_fields = ('path',) \ No newline at end of file + read_only_fields = ('path',) diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 4aa123eb653a..a864bf142449 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -27,6 +27,7 @@ from utils.dataset_manifest import ImageManifestManager, VideoManifestManager from utils.dataset_manifest.core import VideoManifestValidator from utils.dataset_manifest.utils import detect_related_images +from .cloud_provider import get_cloud_storage_instance, Credentials ############################# Low Level server API @@ -221,7 +222,8 @@ def _create_thread(tid, 
data, isImport=False): upload_dir = db_data.get_upload_dirname() if data['remote_files']: - data['remote_files'] = _download_data(data['remote_files'], upload_dir) + if db_data.storage != models.StorageChoice.CLOUD_STORAGE: + data['remote_files'] = _download_data(data['remote_files'], upload_dir) manifest_file = [] media = _count_files(data, manifest_file) @@ -233,8 +235,25 @@ def _create_thread(tid, data, isImport=False): if data['server_files']: if db_data.storage == models.StorageChoice.LOCAL: _copy_data_from_share(data['server_files'], upload_dir) - else: + elif db_data.storage == models.StorageChoice.SHARE: upload_dir = settings.SHARE_ROOT + else: # cloud storage + if not manifest_file: raise Exception('A manifest file not found') + db_cloud_storage = db_data.cloud_storage + credentials = Credentials() + credentials.convert_from_db({ + 'type': db_cloud_storage.credentials_type, + 'value': db_cloud_storage.credentials, # CloudStorage keeps the packed secret in its credentials field; it has no value attribute + }) + + details = { + 'resource': db_cloud_storage.resource, + 'credentials': credentials, + 'specific_attributes': db_cloud_storage.get_specific_attributes() + } + cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details) + cloud_storage_instance.download_file(manifest_file[0], db_data.get_manifest_path()) + cloud_storage_instance.download_file(media['image'][0], os.path.join(upload_dir, media['image'][0])) av_scan_paths(upload_dir) @@ -332,7 +351,13 @@ def update_progress(progress): # calculate chunk size if it isn't specified if db_data.chunk_size is None: if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter): - w, h = extractor.get_image_size(0) + if not (db_data.storage == models.StorageChoice.CLOUD_STORAGE): + w, h = extractor.get_image_size(0) + else: + manifest = ImageManifestManager(db_data.get_manifest_path()) + manifest.init_index() + img_properties = manifest[0] + w, h = img_properties['width'], img_properties['height'] area = h * w db_data.chunk_size = max(2, min(72, 36 * 1920
* 1080 // area)) else: @@ -370,8 +395,8 @@ def _update_status(msg): manifest.validate_frame_numbers() assert len(manifest) > 0, 'No key frames.' - all_frames = manifest['properties']['length'] - video_size = manifest['properties']['resolution'] + all_frames = manifest.video_length + video_size = manifest.video_resolution manifest_is_prepared = True except Exception as ex: if os.path.exists(db_data.get_index_path()): diff --git a/cvat/apps/engine/urls.py b/cvat/apps/engine/urls.py index cb3b25bda5f7..e46228efae6f 100644 --- a/cvat/apps/engine/urls.py +++ b/cvat/apps/engine/urls.py @@ -55,6 +55,7 @@ def _map_format_to_schema(request, scheme=None): router.register('comments', views.CommentViewSet) router.register('restrictions', RestrictionsViewSet, basename='restrictions') router.register('predict', PredictView, basename='predict') +router.register('cloudstorages', views.CloudStorageViewSet) urlpatterns = [ # Entry point for a client diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index f37440731281..87b7b856e301 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -12,6 +12,7 @@ import subprocess import os from av import VideoFrame +from PIL import Image from django.core.exceptions import ValidationError @@ -95,4 +96,6 @@ def rotate_image(image, angle): def md5_hash(frame): if isinstance(frame, VideoFrame): frame = frame.to_image() + elif isinstance(frame, str): + frame = Image.open(frame, 'r') return hashlib.md5(frame.tobytes()).hexdigest() # nosec \ No newline at end of file diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 0e68477c17b6..a389d64d01be 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import io +import json import os import os.path as osp import shutil @@ -10,7 +11,7 @@ import uuid from datetime import datetime from distutils.util import strtobool -from tempfile import mkstemp +from tempfile import mkstemp, 
NamedTemporaryFile import cv2 from django.db.models.query import Prefetch @@ -19,14 +20,14 @@ from django.conf import settings from django.contrib.auth.models import User from django.db import IntegrityError -from django.http import HttpResponse +from django.http import HttpResponse, HttpResponseNotFound, HttpResponseBadRequest from django.shortcuts import get_object_or_404 from django.utils import timezone from django.utils.decorators import method_decorator from django_filters import rest_framework as filters from django_filters.rest_framework import DjangoFilterBackend from drf_yasg import openapi -from drf_yasg.inspectors import CoreAPICompatInspector, NotHandled +from drf_yasg.inspectors import CoreAPICompatInspector, NotHandled, FieldInspector from drf_yasg.utils import swagger_auto_schema from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action @@ -39,13 +40,16 @@ import cvat.apps.dataset_manager as dm import cvat.apps.dataset_manager.views # pylint: disable=unused-import from cvat.apps.authentication import auth +from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer from cvat.apps.engine.frame_provider import FrameProvider from cvat.apps.engine.models import ( Job, StatusChoice, Task, Project, Review, Issue, - Comment, StorageMethodChoice, ReviewStatus, StorageChoice, Image + Comment, StorageMethodChoice, ReviewStatus, StorageChoice, Image, + CredentialsTypeChoice, CloudProviderChoice ) +from cvat.apps.engine.models import CloudStorage as CloudStorageModel from cvat.apps.engine.serializers import ( AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, DataMetaSerializer, DataSerializer, ExceptionSerializer, @@ -53,14 +57,13 @@ LogEventSerializer, ProjectSerializer, ProjectSearchSerializer, ProjectWithoutTaskSerializer, 
RqStatusSerializer, TaskSerializer, UserSerializer, PluginsSerializer, ReviewSerializer, CombinedReviewSerializer, IssueSerializer, CombinedIssueSerializer, CommentSerializer, - TaskFileSerializer, -) + CloudStorageSerializer, BaseCloudStorageSerializer, TaskFileSerializer,) +from utils.dataset_manifest import ImageManifestManager from cvat.apps.engine.utils import av_scan_paths from cvat.apps.engine.backup import import_task from . import models, task from .log import clogger, slogger - class ServerViewSet(viewsets.ViewSet): serializer_class = None @@ -547,9 +550,12 @@ def data(self, request, pk): if data['use_cache']: db_task.data.storage_method = StorageMethodChoice.CACHE db_task.data.save(update_fields=['storage_method']) - if data['server_files'] and data.get('copy_data') == False: + if data['server_files'] and not data.get('copy_data'): db_task.data.storage = StorageChoice.SHARE db_task.data.save(update_fields=['storage']) + if db_data.cloud_storage: + db_task.data.storage = StorageChoice.CLOUD_STORAGE + db_task.data.save(update_fields=['storage']) # if the value of stop_frame is 0, then inside the function we cannot know # the value specified by the user or it's default value from the database if 'stop_frame' not in serializer.validated_data: @@ -1100,6 +1106,199 @@ def self(self, request): serializer = serializer_class(request.user, context={ "request": request }) return Response(serializer.data) +class RedefineDescriptionField(FieldInspector): + # pylint: disable=no-self-use + def process_result(self, result, method_name, obj, **kwargs): + if isinstance(result, openapi.Schema): + if hasattr(result, 'title') and result.title == 'Specific attributes': + result.description = 'structure like key1=value1&key2=value2\n' \ + 'supported: range=aws_range' + return result + +@method_decorator( + name='retrieve', + decorator=swagger_auto_schema( + operation_summary='Method returns details of a specific cloud storage', + responses={ + '200': 
openapi.Response(description='A details of a storage'), + }, + tags=['cloud storages'] + ) +) +@method_decorator(name='list', decorator=swagger_auto_schema( + operation_summary='Returns a paginated list of storages according to query parameters', + manual_parameters=[ + openapi.Parameter('provider_type', openapi.IN_QUERY, description="A supported provider of cloud storages", + type=openapi.TYPE_STRING, enum=CloudProviderChoice.list()), + openapi.Parameter('display_name', openapi.IN_QUERY, description="A display name of storage", type=openapi.TYPE_STRING), + openapi.Parameter('resource', openapi.IN_QUERY, description="A name of bucket or container", type=openapi.TYPE_STRING), + openapi.Parameter('owner', openapi.IN_QUERY, description="A resource owner", type=openapi.TYPE_STRING), + openapi.Parameter('credentials_type', openapi.IN_QUERY, description="A type of a granting access", type=openapi.TYPE_STRING, enum=CredentialsTypeChoice.list()), + ], + responses={'200': BaseCloudStorageSerializer(many=True)}, + tags=['cloud storages'], + field_inspectors=[RedefineDescriptionField] + ) +) +@method_decorator(name='destroy', decorator=swagger_auto_schema( + operation_summary='Method deletes a specific cloud storage', + tags=['cloud storages'] + ) +) +@method_decorator(name='partial_update', decorator=swagger_auto_schema( + operation_summary='Methods does a partial update of chosen fields in a cloud storage instance', + tags=['cloud storages'], + field_inspectors=[RedefineDescriptionField] + ) +) +class CloudStorageViewSet(auth.CloudStorageGetQuerySetMixin, viewsets.ModelViewSet): + http_method_names = ['get', 'post', 'patch', 'delete'] + queryset = CloudStorageModel.objects.all().prefetch_related('data').order_by('-id') + search_fields = ('provider_type', 'display_name', 'resource', 'owner__username') + filterset_fields = ['provider_type', 'display_name', 'resource', 'credentials_type'] + + def get_permissions(self): + http_method = self.request.method + permissions = 
[IsAuthenticated] + + if http_method in SAFE_METHODS: + permissions.append(auth.CloudStorageAccessPermission) + elif http_method in ("POST", "PATCH", "DELETE"): + permissions.append(auth.CloudStorageChangePermission) + else: + permissions.append(auth.AdminRolePermission) + return [perm() for perm in permissions] + + def get_serializer_class(self): + if self.request.method in ("POST", "PATCH"): + return CloudStorageSerializer + else: + return BaseCloudStorageSerializer + + def get_queryset(self): + queryset = super().get_queryset() + if (provider_type := self.request.query_params.get('provider_type', None)): + if provider_type in CloudProviderChoice.list(): + return queryset.filter(provider_type=provider_type) + raise ValidationError('Unsupported type of cloud provider') + return queryset + + def perform_create(self, serializer): + # check that instance of cloud storage exists + provider_type = serializer.validated_data.get('provider_type') + credentials = Credentials( + session_token=serializer.validated_data.get('session_token', ''), + account_name=serializer.validated_data.get('account_name', ''), + key=serializer.validated_data.get('key', ''), + secret_key=serializer.validated_data.get('secret_key', '') + ) + details = { + 'resource': serializer.validated_data.get('resource'), + 'credentials': credentials, + 'specific_attributes': { + item.split('=')[0].strip(): item.split('=')[1].strip() + for item in serializer.validated_data.get('specific_attributes').split('&') + } if len(serializer.validated_data.get('specific_attributes', '')) + else dict() + } + storage = get_cloud_storage_instance(cloud_provider=provider_type, **details) + try: + storage.exists() + except Exception as ex: + message = str(ex) + slogger.glob.error(message) + raise + + owner = self.request.data.get('owner') + if owner: + serializer.save() + else: + serializer.save(owner=self.request.user) + + def perform_destroy(self, instance): + cloud_storage_dirname = instance.get_storage_dirname() + 
super().perform_destroy(instance) + shutil.rmtree(cloud_storage_dirname, ignore_errors=True) + + @method_decorator(name='create', decorator=swagger_auto_schema( + operation_summary='Method creates a cloud storage with a specified characteristics', + responses={ + '201': openapi.Response(description='A storage has beed created') + }, + tags=['cloud storages'], + field_inspectors=[RedefineDescriptionField], + ) + ) + def create(self, request, *args, **kwargs): + try: + response = super().create(request, *args, **kwargs) + except IntegrityError: + response = HttpResponseBadRequest('Same storage already exists') + except ValidationError as exceptions: + msg_body = "" + for ex in exceptions.args: + for field, ex_msg in ex.items(): + msg_body += ": ".join([field, str(ex_msg[0])]) + msg_body += '\n' + return HttpResponseBadRequest(msg_body) + except APIException as ex: + return Response(data=ex.get_full_details(), status=ex.status_code) + except Exception as ex: + response = HttpResponseBadRequest(str(ex)) + return response + + @swagger_auto_schema( + method='get', + operation_summary='Method returns a mapped names of an available files from a storage and a manifest content', + manual_parameters=[ + openapi.Parameter('manifest_path', openapi.IN_QUERY, + description="Path to the manifest file in a cloud storage", + type=openapi.TYPE_STRING) + ], + responses={ + '200': openapi.Response(description='Mapped names of an available files from a storage and a manifest content'), + }, + tags=['cloud storages'] + ) + @action(detail=True, methods=['GET'], url_path='content') + def content(self, request, pk): + try: + db_storage = CloudStorageModel.objects.get(pk=pk) + credentials = Credentials() + credentials.convert_from_db({ + 'type': db_storage.credentials_type, + 'value': db_storage.credentials, + }) + details = { + 'resource': db_storage.resource, + 'credentials': credentials, + 'specific_attributes': db_storage.get_specific_attributes() + } + storage = 
get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details) + storage.initialize_content() + storage_files = storage.content + + manifest_path = request.query_params.get('manifest_path', 'manifest.jsonl') + with NamedTemporaryFile(mode='w+b', suffix='manifest', prefix='cvat') as tmp_manifest: + storage.download_file(manifest_path, tmp_manifest.name) + manifest = ImageManifestManager(tmp_manifest.name) + manifest.init_index() + manifest_files = manifest.data + content = {f:[] for f in set(storage_files) | set(manifest_files)} + for key, _ in content.items(): + if key in storage_files: content[key].append('s') # storage + if key in manifest_files: content[key].append('m') # manifest + + data = json.dumps(content) + return Response(data=data, content_type="aplication/json") + + except CloudStorageModel.DoesNotExist: + message = f"Storage {pk} does not exist" + slogger.glob.error(message) + return HttpResponseNotFound(message) + except Exception as ex: + return HttpResponseBadRequest(str(ex)) + def rq_handler(job, exc_type, exc_value, tb): job.exc_info = "".join( traceback.format_exception_only(exc_type, exc_value)) diff --git a/cvat/requirements/base.txt b/cvat/requirements/base.txt index 08d61d6faf05..a0433b1d3537 100644 --- a/cvat/requirements/base.txt +++ b/cvat/requirements/base.txt @@ -45,6 +45,8 @@ tensorflow==2.4.1 # Optional requirement of Datumaro patool==1.12 diskcache==5.0.2 open3d==0.11.2 +boto3==1.17.61 +azure-storage-blob==12.8.1 # --no-binary=datumaro: workaround for pip to install # opencv-headless instead of regular opencv, to actually run setup script # --no-binary=pycocotools: workaround for binary incompatibility on numpy 1.20 diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 7e633173a205..3da6c06854ea 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -56,8 +56,8 @@ def add_ssh_keys(): IGNORE_FILES = ('README.md', 'ssh.pid') keys_to_add = [entry.name for entry in os.scandir(ssh_dir) if entry.name 
not in IGNORE_FILES] keys_to_add = ' '.join(os.path.join(ssh_dir, f) for f in keys_to_add) - subprocess.run(['ssh-add {}'.format(keys_to_add)], - shell = True, + subprocess.run(['ssh-add {}'.format(keys_to_add)], # nosec + shell=True, stderr = subprocess.PIPE, # lets set the timeout if ssh-add requires a input passphrase for key # otherwise the process will be freezed @@ -68,14 +68,14 @@ def add_ssh_keys(): fcntl.flock(pid, fcntl.LOCK_EX) try: add_ssh_keys() - keys = subprocess.run(['ssh-add -l'], shell = True, + keys = subprocess.run(['ssh-add', '-l'], # nosec stdout = subprocess.PIPE).stdout.decode('utf-8').split('\n') if 'has no identities' in keys[0]: print('SSH keys were not found') volume_keys = os.listdir(keys_dir) if not ('id_rsa' in volume_keys and 'id_rsa.pub' in volume_keys): print('New pair of keys are being generated') - subprocess.run(['ssh-keygen -b 4096 -t rsa -f {}/id_rsa -q -N ""'.format(ssh_dir)], shell = True) + subprocess.run(['ssh-keygen -b 4096 -t rsa -f {}/id_rsa -q -N ""'.format(ssh_dir)], shell=True) # nosec shutil.copyfile('{}/id_rsa'.format(ssh_dir), '{}/id_rsa'.format(keys_dir)) shutil.copymode('{}/id_rsa'.format(ssh_dir), '{}/id_rsa'.format(keys_dir)) shutil.copyfile('{}/id_rsa.pub'.format(ssh_dir), '{}/id_rsa.pub'.format(keys_dir)) @@ -86,15 +86,15 @@ def add_ssh_keys(): shutil.copymode('{}/id_rsa'.format(keys_dir), '{}/id_rsa'.format(ssh_dir)) shutil.copyfile('{}/id_rsa.pub'.format(keys_dir), '{}/id_rsa.pub'.format(ssh_dir)) shutil.copymode('{}/id_rsa.pub'.format(keys_dir), '{}/id_rsa.pub'.format(ssh_dir)) - subprocess.run(['ssh-add', '{}/id_rsa'.format(ssh_dir)], shell = True) + subprocess.run(['ssh-add', '{}/id_rsa'.format(ssh_dir)]) # nosec finally: fcntl.flock(pid, fcntl.LOCK_UN) try: if os.getenv("SSH_AUTH_SOCK", None): generate_ssh_keys() -except Exception: - pass +except Exception as ex: + print(str(ex)) INSTALLED_APPS = [ 'django.contrib.admin', @@ -369,6 +369,9 @@ def add_ssh_keys(): MIGRATIONS_LOGS_ROOT = 
os.path.join(LOGS_ROOT, 'migrations') os.makedirs(MIGRATIONS_LOGS_ROOT, exist_ok=True) +CLOUD_STORAGE_ROOT = os.path.join(DATA_ROOT, 'storages') +os.makedirs(CLOUD_STORAGE_ROOT, exist_ok=True) + LOGGING = { 'version': 1, 'disable_existing_loggers': False, diff --git a/utils/dataset_manifest/core.py b/utils/dataset_manifest/core.py index 7a82f8eace2c..b357daf9b58e 100644 --- a/utils/dataset_manifest/core.py +++ b/utils/dataset_manifest/core.py @@ -5,7 +5,7 @@ import av import json import os -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty from collections import OrderedDict from contextlib import closing from PIL import Image @@ -327,6 +327,10 @@ def __getitem__(self, item): def index(self): return self._index + @abstractproperty + def data(self): + pass + class VideoManifestManager(_ManifestManager): def __init__(self, manifest_path): super().__init__(manifest_path) @@ -376,6 +380,22 @@ def prepare_meta(media_file, upload_dir=None, chunk_size=36, force=False): meta_info.validate_seek_key_frames() return meta_info + @property + def video_name(self): + return self['properties']['name'] + + @property + def video_resolution(self): + return self['properties']['resolution'] + + @property + def video_length(self): + return self['properties']['length'] + + @property + def data(self): + return [self.video_name] + #TODO: add generic manifest structure file validation class ManifestValidator: def validate_base_info(self): @@ -419,7 +439,7 @@ def validate_frame_numbers(self): # not all videos contain information about numbers of frames frames = video_stream.frames if frames: - assert frames == self['properties']['length'], "The uploaded manifest does not match the video" + assert frames == self.video_length, "The uploaded manifest does not match the video" return class ImageManifestManager(_ManifestManager): @@ -452,4 +472,8 @@ def partial_update(self, number, properties): def prepare_meta(sources, **kwargs): meta_info = 
DatasetImagesReader(sources=sources, **kwargs) meta_info.create() - return meta_info \ No newline at end of file + return meta_info + + @property + def data(self): + return [f"{image['name']}{image['extension']}" for _, image in self] \ No newline at end of file