Skip to content

Commit

Permalink
Tried to optimize annotation fetching (#5974)
Browse files Browse the repository at this point in the history
  • Loading branch information
bsekachev authored Apr 5, 2023
1 parent 2090a3c commit 791d93f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 8 deletions.
6 changes: 3 additions & 3 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def _init_tags_from_db(self):
self._extend_attributes(db_tag.labeledimageattributeval_set,
self.db_attributes[db_tag.label_id]["all"].values())

serializer = serializers.LabeledImageSerializer(db_tags, many=True)
serializer = serializers.LabeledImageSerializerFromDB(db_tags, many=True)
self.ir_data.tags = serializer.data

def _init_shapes_from_db(self):
Expand Down Expand Up @@ -453,7 +453,7 @@ def _init_shapes_from_db(self):
for shape_id, shape_elements in elements.items():
shapes[shape_id].elements = shape_elements

serializer = serializers.LabeledShapeSerializer(list(shapes.values()), many=True)
serializer = serializers.LabeledShapeSerializerFromDB(list(shapes.values()), many=True)
self.ir_data.shapes = serializer.data

def _init_tracks_from_db(self):
Expand Down Expand Up @@ -546,7 +546,7 @@ def _init_tracks_from_db(self):
for track_id, track_elements in elements.items():
tracks[track_id].elements = track_elements

serializer = serializers.LabeledTrackSerializer(list(tracks.values()), many=True)
serializer = serializers.LabeledTrackSerializerFromDB(list(tracks.values()), many=True)
self.ir_data.tracks = serializer.data

def _init_version_from_db(self):
Expand Down
6 changes: 2 additions & 4 deletions cvat/apps/engine/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from cvat.apps.engine.models import Location
from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.serializers import DataSerializer, LabeledDataSerializer
from cvat.apps.engine.serializers import DataSerializer
from cvat.apps.webhooks.signals import signal_update, signal_create, signal_delete

class TusFile:
Expand Down Expand Up @@ -278,9 +278,7 @@ def export_annotations(self, request, db_obj, export_func, callback, get_data=No
return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST)

data = get_data(self._object.pk)
serializer = LabeledDataSerializer(data=data)
if serializer.is_valid(raise_exception=True):
return Response(serializer.data)
return Response(data)

def import_annotations(self, request, db_obj, import_func, rq_func, rq_id):
is_tus_request = request.headers.get('Upload-Length', None) is not None or \
Expand Down
58 changes: 57 additions & 1 deletion cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,6 @@ def run_child_validation(self, data):

raise exceptions.ValidationError(errors)


class ShapeSerializer(serializers.Serializer):
type = serializers.ChoiceField(choices=models.ShapeType.choices())
occluded = serializers.BooleanField(default=False)
Expand All @@ -1249,6 +1248,63 @@ class SubLabeledShapeSerializer(ShapeSerializer, AnnotationSerializer):
class LabeledShapeSerializer(SubLabeledShapeSerializer):
elements = SubLabeledShapeSerializer(many=True, required=False)

def _convert_annotation(obj, keys):
return OrderedDict([(key, obj[key]) for key in keys])

def _convert_attributes(attr_set):
attr_keys = ['spec_id', 'value']
return [
OrderedDict([(key, attr[key]) for key in attr_keys]) for attr in attr_set
]

class LabeledImageSerializerFromDB(serializers.BaseSerializer):
# Use this serializer to export data from the database
# Because default DRF serializer is too slow on huge collections
def to_representation(self, instance):
def convert_tag(tag):
result = _convert_annotation(tag, ['id', 'label_id', 'frame', 'group', 'source'])
result['attributes'] = _convert_attributes(tag['labeledimageattributeval_set'])
return result

return convert_tag(instance)

class LabeledShapeSerializerFromDB(serializers.BaseSerializer):
# Use this serializer to export data from the database
# Because default DRF serializer is too slow on huge collections
def to_representation(self, instance):
def convert_shape(shape):
result = _convert_annotation(shape, [
'id', 'label_id', 'type', 'frame', 'group', 'source',
'occluded', 'outside', 'z_order', 'rotation', 'points',
])
result['attributes'] = _convert_attributes(shape['labeledshapeattributeval_set'])
if shape.get('elements', None) is not None and shape['parent'] is None:
result['elements'] = [convert_shape(element) for element in shape['elements']]
return result

return convert_shape(instance)

class LabeledTrackSerializerFromDB(serializers.BaseSerializer):
# Use this serializer to export data from the database
# Because default DRF serializer is too slow on huge collections
def to_representation(self, instance):
def convert_track(track):
shape_keys = [
'id', 'type', 'frame', 'occluded', 'outside', 'z_order',
'rotation', 'points', 'trackedshapeattributeval_set',
]
result = _convert_annotation(track, ['id', 'label_id', 'frame', 'group', 'source'])
result['shapes'] = [_convert_annotation(shape, shape_keys) for shape in track['trackedshape_set']]
result['attributes'] = _convert_attributes(track['labeledtrackattributeval_set'])
for shape in result['shapes']:
shape['attributes'] = _convert_attributes(shape['trackedshapeattributeval_set'])
shape.pop('trackedshapeattributeval_set', None)
if track.get('elements', None) is not None and track['parent'] is None:
result['elements'] = [convert_track(element) for element in track['elements']]
return result

return convert_track(instance)

class TrackedShapeSerializer(ShapeSerializer):
id = serializers.IntegerField(default=None, allow_null=True)
frame = serializers.IntegerField(min_value=0)
Expand Down

0 comments on commit 791d93f

Please sign in to comment.