Skip to content

Commit

Permalink
Introduce new option included_jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
Marishka17 committed Nov 5, 2024
1 parent 8b8140e commit d062b95
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
1 change: 0 additions & 1 deletion cvat/apps/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def __str__(self):

class JobFrameSelectionMethod(str, Enum):
RANDOM_UNIFORM = 'random_uniform'
RANDOM_PER_JOB = 'random_per_job'
MANUAL = 'manual'

@classmethod
Expand Down
39 changes: 34 additions & 5 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from rq.job import Job as RQJob, JobStatus as RQJobStatus
from datetime import timedelta
from decimal import Decimal
from datumaro.util import take_by

from rest_framework import serializers, exceptions
from django.contrib.auth.models import User, Group
from django.db import transaction
from django.utils import timezone
from numpy import random
from django.db.models import Q

from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.engine.frame_provider import TaskFrameProvider
Expand Down Expand Up @@ -965,7 +967,9 @@ def validate(self, attrs):
elif frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM:
pass
else:
assert False
raise serializers.ValidationError(
f"Unexpected frame selection method '{frame_selection_method}'"
)

if (
'honeypot_real_frames' in attrs and
Expand Down Expand Up @@ -1272,6 +1276,12 @@ class TaskValidationLayoutWriteSerializer(serializers.Serializer):
The list of frame ids. Applicable only to the "{}" frame selection method
""".format(models.JobFrameSelectionMethod.MANUAL))
)
included_jobs = serializers.ListField(
child=serializers.IntegerField(min_value=0), required=False,
help_text=textwrap.dedent("""\
The list of jobs to be included when shuffling honeypots in a task
""")
)

def validate(self, attrs):
frame_selection_method = attrs.get("frame_selection_method")
Expand All @@ -1289,6 +1299,14 @@ def validate(self, attrs):
f'"frame_selection_method" is "{models.JobFrameSelectionMethod.MANUAL}"'
)

if (
attrs.get("included_jobs") and frame_selection_method != models.JobFrameSelectionMethod.RANDOM_UNIFORM
):
raise serializers.ValidationError(
"The field 'included_jobs' can only be used with "
f"frame_selection_method=={models.JobFrameSelectionMethod.RANDOM_UNIFORM}"
)

return super().validate(attrs)

@transaction.atomic
Expand Down Expand Up @@ -1343,12 +1361,23 @@ def update(self, instance: models.Task, validated_data: dict[str, Any]) -> model
)

if frame_selection_method:
for db_job in (
jobs_queryset = (
models.Job.objects.select_related("segment")
.filter(segment__task_id=instance.id, type=models.JobType.ANNOTATION)
.exclude(Q(stage=models.StageChoice.ACCEPTANCE) & Q(state=models.StateChoice.COMPLETED))
.order_by("segment__start_frame")
.all()
):
)

# TODO: set upper limit on the number of included_jobs
if (included_job_ids := validated_data.get('included_jobs')):
merged_queryset = models.Job.objects.none()

for job_ids_batch in take_by(included_job_ids, 1000):
merged_queryset = merged_queryset | jobs_queryset.filter(id__in=job_ids_batch)

jobs_queryset = merged_queryset

for db_job in jobs_queryset.order_by("segment__start_frame").all():
job_serializer_params = {
'frame_selection_method': frame_selection_method
}
Expand Down Expand Up @@ -1638,7 +1667,7 @@ def validate(self, attrs):
attrs, ['frames_per_job_count', 'frames_per_job_share']
)
else:
assert False, f"Unknown validation mode {attrs['mode']}"
raise serializers.ValidationError(f"Unknown validation mode {attrs['mode']}")

if attrs['frame_selection_method'] == models.JobFrameSelectionMethod.RANDOM_UNIFORM:
field_validation.require_one_of_fields(attrs, ['frame_count', 'frame_share'])
Expand Down

0 comments on commit d062b95

Please sign in to comment.