Skip to content

Commit

Permalink
Used ModelViewSets instead of GeneralViews.
Browse files Browse the repository at this point in the history
  • Loading branch information
nmanovic committed Feb 1, 2019
1 parent 4feb191 commit 6e7ea98
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 105 deletions.
25 changes: 25 additions & 0 deletions cvat/apps/engine/migrations/0022_auto_20190201_2223.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 2.1.5 on 2019-02-01 19:23

import cvat.apps.engine.models
import django.core.files.storage
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('engine', '0021_task_image_quality'),
]

operations = [
migrations.AlterField(
model_name='clientfile',
name='file',
field=models.FileField(storage=django.core.files.storage.FileSystemStorage(), upload_to=cvat.apps.engine.models.upload_path_handler),
),
migrations.AlterField(
model_name='task',
name='image_quality',
field=models.PositiveSmallIntegerField(),
),
]
41 changes: 18 additions & 23 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,29 @@ class Meta:
def to_internal_value(self, data):
return { 'file' : data }

class RequestStatusSerializer(serializers.Serializer):
state = serializers.ChoiceField(choices=["Unknown",
class RqStatusSerializer(serializers.Serializer):
state = serializers.ChoiceField(choices=[
"Queued", "Started", "Finished", "Failed"])
message = serializers.CharField(allow_blank=True, default="")

class TaskSerializer(serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True)
segments = SegmentSerializer(many=True, source='segment_set', read_only=True)
class TaskDataSerializer(serializers.ModelSerializer):
client_files = ClientFileSerializer(many=True, source='clientfile_set',
write_only=True, default=[])
default=[])
server_files = ServerFileSerializer(many=True, source='serverfile_set',
write_only=True, default=[])
default=[])
remote_files = RemoteFileSerializer(many=True, source='remotefile_set',
write_only=True, default=[])
default=[])

class Meta:
model = Task
fields = ('client_files', 'server_files', 'remote_files')

def create(self, validated_data):
pass

class TaskSerializer(serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True)
segments = SegmentSerializer(many=True, source='segment_set', read_only=True)
image_quality = serializers.IntegerField(min_value=0, max_value=100,
default=50)

Expand All @@ -102,16 +111,13 @@ class Meta:
fields = ('url', 'id', 'name', 'size', 'mode', 'owner', 'assignee',
'bug_tracker', 'created_date', 'updated_date', 'overlap',
'segment_size', 'z_order', 'flipped', 'status', 'labels', 'segments',
'server_files', 'client_files', 'remote_files', 'image_quality')
'image_quality')
read_only_fields = ('size', 'mode', 'created_date', 'updated_date',
'overlap', 'status', 'segment_size')
ordering = ['-id']

def create(self, validated_data):
labels = validated_data.pop('label_set')
client_files = validated_data.pop('clientfile_set')
server_files = validated_data.pop('serverfile_set')
remote_files = validated_data.pop('remotefile_set')
if not validated_data.get('segment_size'):
validated_data['segment_size'] = 0
db_task = Task.objects.create(size=0, **validated_data)
Expand All @@ -121,17 +127,6 @@ def create(self, validated_data):
for attr in attributes:
AttributeSpec.objects.create(label=db_label, **attr)

for obj in client_files:
serializer = ClientFileSerializer(data=obj['file'])
if serializer.is_valid(raise_exception=True):
serializer.save()

for path in server_files:
ServerFile.objects.create(task=db_task, file=path)

for path in remote_files:
RemoteFile.objects.create(task=db_task, file=path)

task_path = db_task.get_task_dirname()
if os.path.isdir(task_path):
shutil.rmtree(task_path)
Expand Down
36 changes: 9 additions & 27 deletions cvat/apps/engine/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,20 @@
#
# SPDX-License-Identifier: MIT

from django.urls import path
from django.urls import path, include
from . import views
from rest_framework import routers

REST_API_PREFIX = 'api/<version>/'

router = routers.DefaultRouter(trailing_slash=False)
router.register("tasks", views.TaskViewSet)
router.register("jobs", views.JobViewSet)
router.register("users", views.UserViewSet)

urlpatterns = [
# entry point for API
path(REST_API_PREFIX, views.api_root, name='root'),
# GET list of users, POST a new user
path(REST_API_PREFIX + 'users/', views.UserList.as_view(),
name='user-list'),
# GET current active user
path(REST_API_PREFIX + 'users/self', views.UserSelf.as_view(),
name='user-self'),
# GET, DELETE, PATCH the user
path(REST_API_PREFIX + 'users/<int:pk>', views.UserDetail.as_view(),
name='user-detail'),
path(REST_API_PREFIX, include(router.urls)),
# GET a frame for a specific task
path(REST_API_PREFIX + 'tasks/<int:pk>/frames/<int:frame>',
views.get_frame, name='task-frame'),
Expand All @@ -28,24 +25,9 @@
name='exception-list'),
# GET information about the backend
path(REST_API_PREFIX + 'about/', views.About.as_view(), name='about'),
# GET a list of jobs for a specific task
path(REST_API_PREFIX + 'tasks/<int:pk>/jobs/', views.JobList.as_view(),
name='job-list'),
# GET and PATCH the specific job
path(REST_API_PREFIX + 'jobs/<int:pk>', views.JobDetail.as_view(),
name='job-detail'),
# GET a list of annotation tasks, POST an annotation task
path(REST_API_PREFIX + 'tasks/', views.TaskList.as_view(),
name='task-list'),
path( # GET, DELETE, PATCH
REST_API_PREFIX + 'tasks/<int:pk>', views.TaskDetail.as_view(),
name='task-detail'),
path( # PUT
REST_API_PREFIX + 'tasks/<int:pk>/data/', views.TaskDetail.as_view(),
REST_API_PREFIX + 'tasks/<int:pk>/data', views.dummy_view,
name='task-data'),
path( # GET
REST_API_PREFIX + 'tasks/<int:pk>/status', views.TaskStatus.as_view(),
name='task-status'),
# GET meta information for all frames
path(REST_API_PREFIX + 'tasks/<int:pk>/frames/meta',
views.get_image_meta_cache, name='image-meta-cache'),
Expand Down
83 changes: 28 additions & 55 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from rest_framework.reverse import reverse
from rest_framework.renderers import JSONRenderer
from rest_framework import status
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework import mixins
import django_rq


Expand All @@ -32,52 +35,33 @@
from cvat.apps.engine.models import StatusChoice, Task, Job
from cvat.apps.engine.serializers import (TaskSerializer, UserSerializer,
ExceptionSerializer, AboutSerializer, JobSerializer, ImageMetaSerializer,
RequestStatusSerializer)
RqStatusSerializer)
from django.contrib.auth.models import User

# Server REST API

@api_view(['GET'])
def api_root(request, version=None):
return Response({
'tasks': reverse('task-list', request=request),
'users': reverse('user-list', request=request),
'myself': reverse('user-self', request=request),
'exceptions': reverse('exception-list', request=request),
'about': reverse('about', request=request),
'plugins': reverse('plugin-list', request=request)
})


class TaskList(generics.ListCreateAPIView):
class TaskViewSet(viewsets.ModelViewSet):
queryset = Task.objects.all()
serializer_class = TaskSerializer

def perform_create(self, serializer):
if self.request.data.get('owner', None):
serializer.save()
else:
serializer.save(owner=self.request.user)
tid = serializer.data["id"]
task.create(tid, serializer.data)

class TaskDetail(generics.RetrieveUpdateDestroyAPIView):
queryset = Task.objects.all()
serializer_class = TaskSerializer
@action(detail=True, methods=['GET'], serializer_class=JobSerializer)
def jobs(self, request, pk, version):
queryset = Job.objects.filter(segment__task_id=pk)
serializer = JobSerializer(queryset, many=True,
context={"request": request})

class TaskStatus(APIView):
serializer_class = RequestStatusSerializer
return Response(serializer.data)

def get(self, request, version, pk):
db_task = get_object_or_404(Task, pk=pk)
response = self._get_response(queue="default",
@action(detail=True, methods=['GET'], serializer_class=RqStatusSerializer)
def status(self, request, pk, version):
response = self._get_rq_response(queue="default",
job_id="/api/{}/tasks/{}".format(version, pk))
serializer = TaskStatus.serializer_class(data=response)
serializer = RqStatusSerializer(data=response)

if serializer.is_valid(raise_exception=True):
return Response(serializer.data)

def _get_response(self, queue, job_id):
def _get_rq_response(self, queue, job_id):
queue = django_rq.get_queue(queue)
job = queue.fetch_job(job_id)
response = {}
Expand All @@ -95,37 +79,26 @@ def _get_response(self, queue, job_id):
return response


class JobList(generics.ListAPIView):
queryset = Job.objects.all()
serializer_class = JobSerializer

def list(self, request, pk, version=None):
queryset = self.queryset.filter(segment__task_id=pk)
serializer = JobSerializer(queryset, many=True,
context={"request": request})
def perform_create(self, serializer):
if self.request.data.get('owner', None):
serializer.save()
else:
serializer.save(owner=self.request.user)

return Response(serializer.data)

class JobDetail(generics.RetrieveUpdateAPIView):
class JobViewSet(viewsets.GenericViewSet,
mixins.RetrieveModelMixin, mixins.UpdateModelMixin):
queryset = Job.objects.all()
serializer_class = JobSerializer

class UserList(generics.ListCreateAPIView):
queryset = User.objects.all()
serializer_class = UserSerializer


class UserDetail(generics.RetrieveUpdateDestroyAPIView):
class UserViewSet(viewsets.ModelViewSet):
queryset = User.objects.all()
serializer_class = UserSerializer


class UserSelf(generics.RetrieveAPIView):
serializer_class = UserSerializer

def get_object(self):
return self.request.user

@action(detail=False, methods=['GET'], serializer_class=UserSerializer)
def self(self, request, version):
serializer = UserSerializer(request.user, context={ "request": request })
return Response(serializer.data)

@login_required
@permission_required(perm=['engine.task.access'],
Expand Down

0 comments on commit 6e7ea98

Please sign in to comment.