Skip to content

Commit

Permalink
Fix GT jobs creation (#7126)
Browse files Browse the repository at this point in the history
- Fixed gt job creation for the whole task size
- Fixed invalid chunk writing for GT jobs
  • Loading branch information
zhiltsov-max authored Nov 28, 2023
1 parent bd27827 commit 81da411
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 7 deletions.
6 changes: 6 additions & 0 deletions changelog.d/20231110_183941_mzhiltsov_fix_gt_jobs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
### Fixed

- It is now possible to create Ground Truth jobs containing all frames in the task
(<https://github.com/opencv/cvat/pull/7126>)
- Incorrect Ground Truth chunks saving
(<https://github.com/opencv/cvat/pull/7126>)
2 changes: 1 addition & 1 deletion cvat/apps/engine/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_task_chunk_data_with_mime(self, chunk_number, quality, db_data):

def get_selective_job_chunk_data_with_mime(self, chunk_number, quality, job):
item = self._get_or_set_cache_item(
key=f'{job.id}_{chunk_number}_{quality}',
key=f'job_{job.id}_{chunk_number}_{quality}',
create_function=lambda: self.prepare_selective_job_chunk(job, quality, chunk_number),
)

Expand Down
4 changes: 2 additions & 2 deletions cvat/apps/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_frame_step(self):
return int(match.group(1)) if match else 1

def get_valid_frame_indices(self):
return range(self.start_frame, self.stop_frame, self.get_frame_step())
return range(self.start_frame, self.stop_frame + 1, self.get_frame_step())

def get_data_dirname(self):
return os.path.join(settings.MEDIA_DATA_ROOT, str(self.id))
Expand Down Expand Up @@ -599,7 +599,7 @@ def clean(self) -> None:
)

if self.stop_frame < self.start_frame:
raise ValidationError("stop_frame cannot be lesser than start_frame")
raise ValidationError("stop_frame cannot be less than start_frame")

return super().clean()

Expand Down
13 changes: 10 additions & 3 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,10 +638,10 @@ def create(self, validated_data):
frame_selection_method = validated_data.pop("frame_selection_method", None)
if frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM:
frame_count = validated_data.pop("frame_count")
if size <= frame_count:
if size < frame_count:
raise serializers.ValidationError(
f"The number of frames requested ({frame_count}) must be lesser than "
f"the number of the task frames ({size})"
f"The number of frames requested ({frame_count}) "
f"must be not be greater than the number of the task frames ({size})"
)

seed = validated_data.pop("seed", None)
Expand All @@ -650,6 +650,13 @@ def create(self, validated_data):
# so here we specify it explicitly
from numpy import random
rng = random.Generator(random.MT19937(seed=seed))

if seed is not None and frame_count < size:
# Reproduce the old (a little bit incorrect) behavior that existed before
# https://github.com/opencv/cvat/pull/7126
# to make the old seed-based sequences reproducible
valid_frame_ids = [v for v in valid_frame_ids if v != task.data.stop_frame]

frames = rng.choice(
list(valid_frame_ids), size=frame_count, shuffle=False, replace=False
).tolist()
Expand Down
29 changes: 28 additions & 1 deletion tests/python/rest_api/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,34 @@ def test_can_create_gt_job_with_random_frames_and_seed(self, admin_user, task_id
with make_api_client(user) as api_client:
(gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(job_id)

assert gt_job_meta.included_frames == frame_ids
assert frame_ids == gt_job_meta.included_frames

@pytest.mark.parametrize("task_mode", ["annotation", "interpolation"])
def test_can_create_gt_job_with_all_frames(self, admin_user, tasks, jobs, task_mode):
user = admin_user
task = next(
t
for t in tasks
if t["mode"] == task_mode
and t["size"]
and not any(j for j in jobs if j["task_id"] == t["id"] and j["type"] == "ground_truth")
)
task_id = task["id"]

job_spec = {
"task_id": task_id,
"type": "ground_truth",
"frame_selection_method": "random_uniform",
"frame_count": task["size"],
}

response = self._test_create_job_ok(user, job_spec)
job_id = json.loads(response.data)["id"]

with make_api_client(user) as api_client:
(gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(job_id)

assert task["size"] == gt_job_meta.size

def test_can_create_no_more_than_1_gt_job(self, admin_user, jobs):
user = admin_user
Expand Down

0 comments on commit 81da411

Please sign in to comment.