Skip to content

Commit

Permalink
Merge pull request #2353 from Rusteam/feature/cvat-custom-task-name
Browse files Browse the repository at this point in the history
Provide custom task name for CVAT
  • Loading branch information
brimoor authored Dec 5, 2022
2 parents 166e464 + 12208fb commit 0c11bb0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/integrations/cvat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ provided:
otherwise a new project is created. By default, no project is used
- **project_id** (*None*): an optional ID of an existing CVAT project to
which to upload the annotation tasks. By default, no project is used
- **task_name** (None): an optional task name to use for the created CVAT task
- **occluded_attr** (*None*): an optional attribute name containing existing
occluded values and/or in which to store downloaded occluded values for all
objects in the annotation run
Expand Down
17 changes: 15 additions & 2 deletions fiftyone/utils/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
import math
from collections import defaultdict
from copy import copy, deepcopy
from datetime import datetime
Expand Down Expand Up @@ -3060,6 +3061,7 @@ class CVATBackendConfig(foua.AnnotationBackendConfig):
default, no project is used
project_id (None): an optional ID of an existing CVAT project to which
to upload the annotation tasks. By default, no project is used
task_name (None): an optional task name to use for the created CVAT task
occluded_attr (None): an optional attribute name containing existing
occluded values and/or in which to store downloaded occluded values
for all objects in the annotation run
Expand Down Expand Up @@ -3091,6 +3093,7 @@ def __init__(
job_reviewers=None,
project_name=None,
project_id=None,
task_name=None,
occluded_attr=None,
group_id_attr=None,
issue_tracker=None,
Expand All @@ -3109,6 +3112,7 @@ def __init__(
self.job_reviewers = job_reviewers
self.project_name = project_name
self.project_id = project_id
self.task_name = task_name
self.occluded_attr = occluded_attr
self.group_id_attr = group_id_attr
self.issue_tracker = issue_tracker
Expand Down Expand Up @@ -4226,6 +4230,7 @@ def upload_samples(self, samples, backend):

num_samples = len(samples)
batch_size = self._get_batch_size(samples, task_size)
num_batches = math.ceil(num_samples / batch_size)

samples.compute_metadata()

Expand Down Expand Up @@ -4290,8 +4295,16 @@ def upload_samples(self, samples, backend):
project_id = self.create_project(project_name, cvat_schema)
project_ids.append(project_id)

_dataset_name = samples_batch._dataset.name.replace(" ", "_")
task_name = "FiftyOne_%s" % _dataset_name
if config.task_name is None:
_dataset_name = samples_batch._dataset.name.replace(
" ", "_"
)
task_name = f"FiftyOne_{_dataset_name}"
else:
task_name = config.task_name
# append task number when multiple tasks are created
if num_batches > 1:
task_name += f"_{idx + 1}"

(
task_id,
Expand Down
5 changes: 4 additions & 1 deletion tests/intensive/cvat_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def test_task_creation_arguments(self):

anno_key = "anno_key"
bug_tracker = "test_tracker"
task_name = "test_task"
results = dataset.annotate(
anno_key,
backend="cvat",
Expand All @@ -427,15 +428,17 @@ def test_task_creation_arguments(self):
job_assignees=users,
job_reviewers=users,
issue_tracker=bug_tracker,
task_name=task_name,
)
task_ids = results.task_ids
with results:
api = results.connect_to_api()
self.assertEqual(len(task_ids), 2)
for task_id in task_ids:
for idx, task_id in enumerate(task_ids):
task_json = api.get(api.task_url(task_id)).json()
self.assertEqual(task_json["bug_tracker"], bug_tracker)
self.assertEqual(task_json["segment_size"], 1)
self.assertEqual(task_json["name"], f"{task_name}_{idx + 1}")
if user is not None:
self.assertEqual(task_json["assignee"]["username"], user)
for job in api.get(api.jobs_url(task_id)).json():
Expand Down

0 comments on commit 0c11bb0

Please sign in to comment.