From d062b95f7470f5ec3638b6855046655dc6949a39 Mon Sep 17 00:00:00 2001 From: maya Date: Tue, 5 Nov 2024 13:33:23 +0100 Subject: [PATCH] Introduce new option included_jobs --- cvat/apps/engine/models.py | 1 - cvat/apps/engine/serializers.py | 39 ++++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index 6212ce3a8bc..903d298a5f4 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -177,7 +177,6 @@ def __str__(self): class JobFrameSelectionMethod(str, Enum): RANDOM_UNIFORM = 'random_uniform' - RANDOM_PER_JOB = 'random_per_job' MANUAL = 'manual' @classmethod diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 5b3845f8260..e2a6fb139d9 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -20,12 +20,14 @@ from rq.job import Job as RQJob, JobStatus as RQJobStatus from datetime import timedelta from decimal import Decimal +from datumaro.util import take_by from rest_framework import serializers, exceptions from django.contrib.auth.models import User, Group from django.db import transaction from django.utils import timezone from numpy import random +from django.db.models import Q from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.engine.frame_provider import TaskFrameProvider @@ -965,7 +967,9 @@ def validate(self, attrs): elif frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM: pass else: - assert False + raise serializers.ValidationError( + f"Unexpected frame selection method '{frame_selection_method}'" + ) if ( 'honeypot_real_frames' in attrs and @@ -1272,6 +1276,12 @@ class TaskValidationLayoutWriteSerializer(serializers.Serializer): The list of frame ids. Applicable only to the "{}" frame selection method """.format(models.JobFrameSelectionMethod.MANUAL)) ) + included_jobs = serializers.ListField( + child=serializers.IntegerField(min_value=0), required=False, + help_text=textwrap.dedent("""\ + The list of jobs to be included when shuffling honeypots in a task + """) + ) def validate(self, attrs): frame_selection_method = attrs.get("frame_selection_method") @@ -1289,6 +1299,14 @@ def validate(self, attrs): f'"frame_selection_method" is "{models.JobFrameSelectionMethod.MANUAL}"' ) + if ( + attrs.get("included_jobs") and frame_selection_method != models.JobFrameSelectionMethod.RANDOM_UNIFORM + ): + raise serializers.ValidationError( + "The field 'included_jobs' can only be used with " + f"frame_selection_method=={models.JobFrameSelectionMethod.RANDOM_UNIFORM}" + ) + return super().validate(attrs) @transaction.atomic @@ -1343,12 +1361,23 @@ def update(self, instance: models.Task, validated_data: dict[str, Any]) -> model ) if frame_selection_method: - for db_job in ( + jobs_queryset = ( models.Job.objects.select_related("segment") .filter(segment__task_id=instance.id, type=models.JobType.ANNOTATION) + .exclude(Q(stage=models.StageChoice.ACCEPTANCE) & Q(state=models.StateChoice.COMPLETED)) .order_by("segment__start_frame") - .all() - ): + ) + + # TODO: set upper limit on the number of included_jobs + if (included_job_ids := validated_data.get('included_jobs')): + merged_queryset = models.Job.objects.none() + + for job_ids_batch in take_by(included_job_ids, 1000): + merged_queryset = merged_queryset | jobs_queryset.filter(id__in=job_ids_batch) + + jobs_queryset = merged_queryset + + for db_job in jobs_queryset.order_by("segment__start_frame").all(): job_serializer_params = { 'frame_selection_method': frame_selection_method } @@ -1638,7 +1667,7 @@ def validate(self, attrs): attrs, ['frames_per_job_count', 'frames_per_job_share'] ) else: - assert False, f"Unknown validation mode {attrs['mode']}" + raise serializers.ValidationError(f"Unknown validation mode {attrs['mode']}") if attrs['frame_selection_method'] == models.JobFrameSelectionMethod.RANDOM_UNIFORM: field_validation.require_one_of_fields(attrs, ['frame_count', 'frame_share'])