Skip to content

Commit

Permalink
Some fixes & bandit & add specific attr
Browse files Browse the repository at this point in the history
  • Loading branch information
Marishka17 committed Apr 22, 2021
1 parent 6413906 commit 91c0e42
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 81 deletions.
10 changes: 6 additions & 4 deletions cvat/apps/engine/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ImageDatasetManifestReader, VideoDatasetManifestReader)
from cvat.apps.engine.models import DataChoice, StorageChoice
from cvat.apps.engine.models import DimensionType
from cvat.apps.engine.cloud_provider import CloudStorage, Credentials
from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials
from cvat.apps.engine.utils import md5_hash
class CacheInteraction:
def __init__(self, dimension=DimensionType.DIM_2D):
Expand Down Expand Up @@ -50,7 +50,8 @@ def prepare_chunk_buff(self, db_data, quality, chunk_number):
buff = BytesIO()
upload_dir = {
StorageChoice.LOCAL: db_data.get_upload_dirname(),
StorageChoice.SHARE: settings.SHARE_ROOT
StorageChoice.SHARE: settings.SHARE_ROOT,
StorageChoice.CLOUD_STORAGE: db_data.get_upload_dirname()
}[db_data.storage]
if hasattr(db_data, 'video'):
source_path = os.path.join(upload_dir, db_data.video.path)
Expand All @@ -75,9 +76,10 @@ def prepare_chunk_buff(self, db_data, quality, chunk_number):
})
details = {
'resource': db_cloud_storage.resource,
'credentials': credentials
'credentials': credentials,
'specific_attributes': db_cloud_storage.get_specific_attributes()
}
cloud_storage_instance = CloudStorage(cloud_provider=db_cloud_storage.provider_type, **details)
cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details)
cloud_storage_instance.initialize_content()
for item in reader:
name = f"{item['name']}{item['extension']}"
Expand Down
76 changes: 35 additions & 41 deletions cvat/apps/engine/cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def upload_file(self, file_obj, file_name):
pass

def __contains__(self, file_name):
return file_name in (item['name'] for item in self._files.values())
return file_name in (item['name'] for item in self._files)

def __len__(self):
    """Return the number of entries currently held in ``self._files``."""
    return len(self._files)
Expand All @@ -67,53 +67,41 @@ def __len__(self):
def content(self):
return list(map(lambda x: x['name'] , self._files))

# def get_cloud_storage_instance(cloud_provider, resource, credentials):
# instance = None
# check the credentials!
# if cloud_provider == CloudProviderChoice.AWS_S3:
# instance = AWS_S3(
# bucket=resource,
# session_token=credentials.session_token,
# )
# elif cloud_provider == CloudProviderChoice.AZURE_CONTAINER:
# instance = AzureBlobContainer(
# container_name=resource,
# sas_token=credentials.session_token,
# )
# return instance

# TODO: consider keeping the provider function instead of the class below
class CloudStorage:
    """Thin proxy that forwards attribute access to a provider-specific storage.

    The concrete backend (``AWS_S3`` or ``AzureBlobContainer``) is selected
    from ``cloud_provider`` at construction time; every attribute that is not
    found on the proxy itself is looked up on that backend.
    """

    def __init__(self, cloud_provider, resource, credentials):
        """Create the backend matching ``cloud_provider``.

        Args:
            cloud_provider: one of the ``CloudProviderChoice`` values.
            resource: bucket name (AWS S3) or container name (Azure).
            credentials: object exposing ``key``, ``secret_key``,
                ``session_token`` and ``account_name``.

        Raises:
            NotImplementedError: if ``cloud_provider`` is not supported.
        """
        if cloud_provider == CloudProviderChoice.AWS_S3:
            self.__instance = AWS_S3(
                bucket=resource,
                access_key_id=credentials.key,
                secret_key=credentials.secret_key,
                session_token=credentials.session_token,
            )
        elif cloud_provider == CloudProviderChoice.AZURE_CONTAINER:
            self.__instance = AzureBlobContainer(
                container=resource,
                account_name=credentials.account_name,
                sas_token=credentials.session_token,
            )
        else:
            raise NotImplementedError()

    def __getattr__(self, name):
        """Delegate lookups of unknown attributes to the wrapped backend.

        Raises:
            AttributeError: if the backend is not set yet or does not have
                ``name``. ``AttributeError`` (not ``AssertionError``) is what
                ``hasattr`` and ``getattr(obj, name, default)`` rely on.
        """
        # Read the name-mangled slot straight from __dict__: going through
        # ``self.__instance`` would re-enter __getattr__ and recurse forever
        # when the backend has not been assigned (e.g. during copy/unpickle).
        try:
            instance = self.__dict__['_CloudStorage__instance']
        except KeyError:
            raise AttributeError(name) from None
        try:
            return getattr(instance, name)
        except AttributeError:
            raise AttributeError('Unknown behavior: {}'.format(name)) from None
def get_cloud_storage_instance(cloud_provider, resource, credentials, specific_attributes=None):
    """Build the storage wrapper that matches ``cloud_provider``.

    Args:
        cloud_provider: one of the ``CloudProviderChoice`` values.
        resource: bucket name (AWS S3) or container name (Azure).
        credentials: object exposing ``key``, ``secret_key``,
            ``session_token`` and ``account_name``.
        specific_attributes: optional mapping of provider-specific settings.
            Only ``'region'`` is read here (AWS S3, default ``'us-east-2'``).

    Returns:
        An ``AWS_S3`` or ``AzureBlobContainer`` instance.

    Raises:
        NotImplementedError: if ``cloud_provider`` is not supported.
    """
    # The parameter defaults to None, so normalise it before calling .get();
    # otherwise omitting it fails with AttributeError on the AWS branch.
    attributes = specific_attributes or {}

    if cloud_provider == CloudProviderChoice.AWS_S3:
        return AWS_S3(
            bucket=resource,
            access_key_id=credentials.key,
            secret_key=credentials.secret_key,
            session_token=credentials.session_token,
            region=attributes.get('region', 'us-east-2'),
        )
    if cloud_provider == CloudProviderChoice.AZURE_CONTAINER:
        return AzureBlobContainer(
            container=resource,
            account_name=credentials.account_name,
            sas_token=credentials.session_token,
        )
    raise NotImplementedError(
        'Unsupported cloud provider: {}'.format(cloud_provider)
    )

class AWS_S3(_CloudStorage):
def __init__(self, bucket, access_key_id=None, secret_key=None, session_token=None):
def __init__(self,
bucket,
region,
access_key_id=None,
secret_key=None,
session_token=None):
super().__init__()
if all([access_key_id, secret_key, session_token]):
self._client_s3 = boto3.client(
's3',
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_key,
aws_session_token=session_token,
region_name=region
)
elif any([access_key_id, secret_key, session_token]):
raise Exception('Insufficient data for authorization')
Expand All @@ -123,6 +111,7 @@ def __init__(self, bucket, access_key_id=None, secret_key=None, session_token=No
self._s3.meta.client.meta.events.register('choose-signer.s3.*', disable_signing)
self._client_s3 = self._s3.meta.client
self._bucket = self._s3.Bucket(bucket)
self.region = region

@property
def bucket(self):
Expand Down Expand Up @@ -195,13 +184,18 @@ def download_fileobj(self, key):

def create(self):
try:
_ = self._bucket.create(
responce = self._bucket.create(
ACL='private',
CreateBucketConfiguration={
'LocationConstraint': 'us-east-2',#TODO
'LocationConstraint': self.region,
},
ObjectLockEnabledForBucket=False
)
slogger.glob.info(
'Bucket {} has been created on {} region'.format(
self.name,
responce['Location']
))
except Exception as ex:
msg = str(ex)
slogger.glob.info(msg)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.1.7 on 2021-04-14 21:10
# Generated by Django 3.1.7 on 2021-04-22 09:45

import cvat.apps.engine.models
from django.conf import settings
Expand Down Expand Up @@ -29,6 +29,8 @@ class Migration(migrations.Migration):
('updated_date', models.DateTimeField(auto_now=True)),
('credentials', models.CharField(max_length=500)),
('credentials_type', models.CharField(choices=[('TEMP_KEY_SECRET_KEY_TOKEN_PAIR', 'TEMP_KEY_SECRET_KEY_TOKEN_PAIR'), ('ACCOUNT_NAME_TOKEN_PAIR', 'ACCOUNT_NAME_TOKEN_PAIR'), ('ANONYMOUS_ACCESS', 'ANONYMOUS_ACCESS')], max_length=30)),
('specific_attributes', models.CharField(blank=True, max_length=50)),
('description', models.TextField(default='')),
('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='cloud_storages', to=settings.AUTH_USER_MODEL)),
],
options={
Expand Down
16 changes: 13 additions & 3 deletions cvat/apps/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,9 @@ def __str__(self):
return self.value

class CredentialsTypeChoice(str, Enum):
TEMP_KEY_SECRET_KEY_TOKEN_PAIR = 'TEMP_KEY_SECRET_KEY_TOKEN_PAIR'
ACCOUNT_NAME_TOKEN_PAIR = 'ACCOUNT_NAME_TOKEN_PAIR'
# ignore bandit issues because they are false positives
TEMP_KEY_SECRET_KEY_TOKEN_PAIR = 'TEMP_KEY_SECRET_KEY_TOKEN_PAIR' # nosec
ACCOUNT_NAME_TOKEN_PAIR = 'ACCOUNT_NAME_TOKEN_PAIR' # nosec
ANONYMOUS_ACCESS = 'ANONYMOUS_ACCESS'

@classmethod
Expand All @@ -578,6 +579,8 @@ class CloudStorage(models.Model):
updated_date = models.DateTimeField(auto_now=True)
credentials = models.CharField(max_length=500)
credentials_type = models.CharField(max_length=30, choices=CredentialsTypeChoice.choices())#auth_type
specific_attributes = models.CharField(max_length=50, blank=True)
description = models.TextField(default='')

class Meta:
default_permissions = ()
Expand All @@ -593,4 +596,11 @@ def get_storage_logs_dirname(self):
return os.path.join(self.get_storage_dirname(), 'logs')

def get_log_path(self):
return os.path.join(self.get_storage_dirname(), "storage.log")
return os.path.join(self.get_storage_dirname(), "storage.log")

def get_specific_attributes(self):
    """Parse ``self.specific_attributes`` into a dict.

    The field is read as ``'key=value&key2=value2'``; keys and values are
    stripped of surrounding whitespace.

    Returns:
        dict mapping attribute names to their string values; empty when the
        field is blank.
    """
    parsed = {}
    for item in (self.specific_attributes or '').split('&'):
        # ''.split('&') yields [''], so a blank field (or a trailing '&')
        # produces segments without '='; skip them instead of indexing past
        # the end of item.split('=').
        key, separator, value = item.partition('=')
        if not separator:
            continue
        # partition() splits on the first '=' only, so values that contain
        # '=' themselves are kept whole.
        parsed[key.strip()] = value.strip()
    return parsed
13 changes: 9 additions & 4 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.engine import models
from cvat.apps.engine.cloud_provider import Credentials, CloudStorage
from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials
from cvat.apps.engine.log import slogger

class BasicUserSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -725,7 +725,7 @@ class Meta:
fields = (
'provider_type', 'resource', 'owner', 'credentials_type',
'created_date', 'updated_date', 'session_token', 'account_name', 'key',
'secret_key'
'secret_key', 'specific_attributes', 'description'
)
read_only_fields = ('created_date', 'updated_date', 'owner')

Expand All @@ -748,9 +748,14 @@ def create(self, validated_data):
if should_be_created:
details = {
'resource': validated_data.get('resource'),
'credentials': credentials
'credentials': credentials,
'specific_attributes': {
item.split('=')[0].strip(): item.split('=')[1].strip()
for item in validated_data.get('specific_attributes').split('&')
} if len(validated_data.get('specific_attributes', ''))
else dict()
}
storage = CloudStorage(cloud_provider=provider_type, **details)
storage = get_cloud_storage_instance(cloud_provider=provider_type, **details)
try:
storage.create()
except Exception as ex:
Expand Down
43 changes: 25 additions & 18 deletions cvat/apps/engine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from . import models
from .log import slogger
from .cloud_provider import CloudStorage, Credentials
from .cloud_provider import get_cloud_storage_instance, Credentials

############################# Low Level server API

Expand Down Expand Up @@ -236,6 +236,23 @@ def _create_thread(tid, data):
_copy_data_from_share(data['server_files'], upload_dir)
elif db_data.storage == StorageChoice.SHARE:
upload_dir = settings.SHARE_ROOT
else: # cloud storage
if not manifest_file: raise Exception('A manifest file not found')
db_cloud_storage = db_data.cloud_storage
credentials = Credentials()
credentials.convert_from_db({
'type': db_cloud_storage.credentials_type,
'value': db_cloud_storage.value,
})

details = {
'resource': db_cloud_storage.resource,
'credentials': credentials,
'specific_attributes': db_cloud_storage.get_specific_attributes()
}
cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details)
cloud_storage_instance.download_file(manifest_file[0], db_data.get_manifest_path())
cloud_storage_instance.download_file(media['image'][0], os.path.join(upload_dir, media['image'][0]))

av_scan_paths(upload_dir)

Expand Down Expand Up @@ -317,7 +334,13 @@ def update_progress(progress):
# calculate chunk size if it isn't specified
if db_data.chunk_size is None:
if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter):
w, h = extractor.get_image_size(0)
if not (db_data.storage == StorageChoice.CLOUD_STORAGE):
w, h = extractor.get_image_size(0)
else:
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.init_index()
img_properties = manifest[0]
w, h = img_properties['width'], img_properties['height']
area = h * w
db_data.chunk_size = max(2, min(72, 36 * 1920 * 1080 // area))
else:
Expand Down Expand Up @@ -401,8 +424,6 @@ def _update_status(msg):
db_data.size = len(extractor)
manifest = ImageManifestManager(db_data.get_manifest_path())
if not manifest_file:
if db_data.storage == StorageChoice.CLOUD_STORAGE:
raise Exception('A manifest file was not foud')
if db_task.dimension == DimensionType.DIM_2D:
meta_info = manifest.prepare_meta(
sources=extractor.absolute_source_paths,
Expand All @@ -418,20 +439,6 @@ def _update_status(msg):
'extension': ext
})
manifest.create(content)
if db_data.storage == StorageChoice.CLOUD_STORAGE:
db_cloud_storage = db_data.cloud_storage
credentials = Credentials()
credentials.convert_from_db({
'type': db_cloud_storage.credentials_type,
'value': db_cloud_storage.value,
})

details = {
'resource': db_cloud_storage.resource,
'credentials': credentials,
}
cloud_storage_instance = CloudStorage(cloud_provider=db_cloud_storage.provider_type, **details)
cloud_storage_instance.download_file(manifest_file[0], db_data.get_manifest_path())
manifest.init_index()
counter = itertools.count()
for _, chunk_frames in itertools.groupby(extractor.frame_range, lambda x: next(counter) // db_data.chunk_size):
Expand Down
3 changes: 3 additions & 0 deletions cvat/apps/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import subprocess
import os
from av import VideoFrame
from PIL import Image

from django.core.exceptions import ValidationError

Expand Down Expand Up @@ -95,4 +96,6 @@ def rotate_image(image, angle):
def md5_hash(frame):
    """Return the hex MD5 digest of a frame's raw pixel bytes.

    Args:
        frame: a ``VideoFrame`` (converted with ``to_image()``), a ``str``
            that is passed to ``Image.open`` as a path, or any object that
            already provides ``tobytes()``.

    Returns:
        The digest as a hexadecimal string.
    """
    if isinstance(frame, VideoFrame):
        frame = frame.to_image()
    elif isinstance(frame, str):
        # Image.open leaves the underlying file open; the context manager
        # closes it as soon as the pixel data has been hashed.
        with Image.open(frame, 'r') as image:
            return hashlib.md5(image.tobytes()).hexdigest() # nosec
    # MD5 is used as a content fingerprint here, not for security.
    return hashlib.md5(frame.tobytes()).hexdigest() # nosec
12 changes: 9 additions & 3 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import cvat.apps.dataset_manager as dm
import cvat.apps.dataset_manager.views # pylint: disable=unused-import
from cvat.apps.authentication import auth
from cvat.apps.engine.cloud_provider import Credentials, CloudStorage
from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer
from cvat.apps.engine.frame_provider import FrameProvider
Expand Down Expand Up @@ -1067,8 +1067,13 @@ def perform_create(self, serializer):
details = {
'resource': serializer.validated_data.get('resource'),
'credentials': credentials,
'specific_attributes': {
item.split('=')[0].strip(): item.split('=')[1].strip()
for item in serializer.validated_data.get('specific_attributes').split('&')
} if len(serializer.validated_data.get('specific_attributes', ''))
else dict()
}
storage = CloudStorage(cloud_provider=provider_type, **details)
storage = get_cloud_storage_instance(cloud_provider=provider_type, **details)
try:
storage.is_exist()
except Exception as ex:
Expand Down Expand Up @@ -1120,8 +1125,9 @@ def retrieve(self, request, *args, **kwargs):
details = {
'resource': db_storage.resource,
'credentials': credentials,
'specific_attributes': db_storage.get_specific_attributes()
}
storage = CloudStorage(cloud_provider=db_storage.provider_type, **details)
storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details)
storage.initialize_content()
storage_files = storage.content

Expand Down
Loading

0 comments on commit 91c0e42

Please sign in to comment.