Skip to content

Commit

Permalink
Merge branch 'develop' into zm/compare-point-groups-in-image-space
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored Nov 1, 2024
2 parents bc141e7 + ceba1b7 commit 8826c0f
Show file tree
Hide file tree
Showing 34 changed files with 182 additions and 146 deletions.
4 changes: 4 additions & 0 deletions changelog.d/20241016_133802_sekachev.bs_fixed_propagation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Fixed

- Propagation creates copies on non-existing frames in a ground truth job
(<https://github.com/cvat-ai/cvat/pull/8550>)
8 changes: 5 additions & 3 deletions cvat-core/src/annotations-actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: MIT

import { omit, throttle } from 'lodash';
import { omit, range, throttle } from 'lodash';
import { ArgumentError } from './exceptions';
import { SerializedCollection, SerializedShape } from './server-response-types';
import { Job, Task } from './session';
Expand Down Expand Up @@ -107,13 +107,15 @@ class PropagateShapes extends BaseSingleFrameAction {
}

public async run(
instance,
instance: Job | Task,
{ collection: { shapes }, frameData: { number } },
): Promise<SingleFrameActionOutput> {
if (number === this.#targetFrame) {
return { collection: { shapes } };
}
const propagatedShapes = propagateShapes<SerializedShape>(shapes, number, this.#targetFrame);

const frameNumbers = instance instanceof Job ? await instance.frames.frameNumbers() : range(0, instance.size);
const propagatedShapes = propagateShapes<SerializedShape>(shapes, number, this.#targetFrame, frameNumbers);
return { collection: { shapes: [...shapes, ...propagatedShapes] } };
}

Expand Down
13 changes: 7 additions & 6 deletions cvat-core/src/frames.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ export class FramesMetaData {
return Math.floor(this.getFrameIndex(dataFrameNumber) / this.chunkSize);
}

getSegmentFrameNumbers(jobStartFrame: number): number[] {
const frames = this.getDataFrameNumbers();
return frames.map((frame) => this.getJobRelativeFrameNumber(frame) + jobStartFrame);
}

getDataFrameNumbers(): number[] {
if (this.includedFrames) {
return [...this.includedFrames];
Expand Down Expand Up @@ -348,9 +353,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', {
const requestId = +_.uniqueId();
const requestedDataFrameNumber = meta.getDataFrameNumber(this.number - jobStartFrame);
const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber);
const segmentFrameNumbers = meta.getDataFrameNumbers().map((dataFrameNumber: number) => (
meta.getJobRelativeFrameNumber(dataFrameNumber) + jobStartFrame
));
const segmentFrameNumbers = meta.getSegmentFrameNumbers(jobStartFrame);
const frame = provider.frame(this.number);

function findTheNextNotDecodedChunk(currentFrameIndex: number): number | null {
Expand Down Expand Up @@ -889,9 +892,7 @@ export function getJobFrameNumbers(jobID: number): number[] {
}

const { meta, jobStartFrame } = frameDataCache[jobID];
return meta.getDataFrameNumbers().map((dataFrameNumber: number): number => (
meta.getJobRelativeFrameNumber(dataFrameNumber) + jobStartFrame
));
return meta.getSegmentFrameNumbers(jobStartFrame);
}

export function clear(jobID: number): void {
Expand Down
15 changes: 12 additions & 3 deletions cvat-core/src/object-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ export function rle2Mask(rle: number[], width: number, height: number): number[]
}

export function propagateShapes<T extends SerializedShape | ObjectState>(
shapes: T[], from: number, to: number,
shapes: T[], from: number, to: number, frameNumbers: number[],
): T[] {
const getCopy = (shape: T): SerializedShape | SerializedData => {
if (shape instanceof ObjectState) {
Expand Down Expand Up @@ -397,9 +397,18 @@ export function propagateShapes<T extends SerializedShape | ObjectState>(
};
};

const targetFrameNumbers = frameNumbers.filter(
(frameNumber: number) => frameNumber >= Math.min(from, to) &&
frameNumber <= Math.max(from, to) &&
frameNumber !== from,
);

const states: T[] = [];
const sign = Math.sign(to - from);
for (let frame = from + sign; sign > 0 ? frame <= to : frame >= to; frame += sign) {
for (const frame of targetFrameNumbers) {
if (frame === from) {
continue;
}

for (const shape of shapes) {
const copy = getCopy(shape);

Expand Down
8 changes: 8 additions & 0 deletions cvat-core/src/session-implementation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,14 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass {
},
});

Object.defineProperty(Task.prototype.frames.frameNumbers, 'implementation', {
value: function includedFramesImplementation(
this: TaskClass,
): ReturnType<typeof TaskClass.prototype.frames.frameNumbers> {
throw new Error('Not implemented for Task');
},
});

Object.defineProperty(Task.prototype.frames.preview, 'implementation', {
value: function previewImplementation(
this: TaskClass,
Expand Down
8 changes: 4 additions & 4 deletions cvat-core/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ export class Session {
};

public actions: {
undo: (count: number) => Promise<number[]>;
redo: (count: number) => Promise<number[]>;
undo: (count?: number) => Promise<number[]>;
redo: (count?: number) => Promise<number[]>;
freeze: (frozen: boolean) => Promise<void>;
clear: () => Promise<void>;
get: () => Promise<{ undo: [HistoryActions, number][], redo: [HistoryActions, number][] }>;
Expand Down Expand Up @@ -403,8 +403,8 @@ export class Session {
public logger: {
log: (
scope: Parameters<typeof logger.log>[0],
payload: Parameters<typeof logger.log>[1],
wait: Parameters<typeof logger.log>[2],
payload?: Parameters<typeof logger.log>[1],
wait?: Parameters<typeof logger.log>[2],
) => ReturnType<typeof logger.log>;
};

Expand Down
5 changes: 3 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# SPDX-License-Identifier: MIT

import logging
from typing import List, Mapping, Optional, Sequence
from collections.abc import Mapping, Sequence
from typing import Optional

import attrs

Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(
fun_label, ds_labels_by_name
)

def validate_and_remap(self, shapes: List[models.LabeledShapeRequest], ds_frame: int) -> None:
def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None:
new_shapes = []

for shape in shapes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: MIT

from functools import cached_property
from typing import List

import PIL.Image
import torchvision.models
Expand All @@ -28,7 +27,7 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)

def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
results = self._model([self._transforms(image)])

return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# SPDX-License-Identifier: MIT

from functools import cached_property
from typing import List

import PIL.Image
import torchvision.models
Expand Down Expand Up @@ -36,7 +35,7 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)

def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]:
def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
results = self._model([self._transforms(image)])

return [
Expand Down
5 changes: 3 additions & 2 deletions cvat-sdk/cvat_sdk/auto_annotation/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# SPDX-License-Identifier: MIT

import abc
from typing import List, Protocol, Sequence
from collections.abc import Sequence
from typing import Protocol

import attrs
import PIL.Image
Expand Down Expand Up @@ -79,7 +80,7 @@ def spec(self) -> DetectionFunctionSpec:

def detect(
self, context: DetectionFunctionContext, image: PIL.Image.Image
) -> List[models.LabeledShapeRequest]:
) -> list[models.LabeledShapeRequest]:
"""
Detects objects on the supplied image and returns the results.
Expand Down
15 changes: 8 additions & 7 deletions cvat-sdk/cvat_sdk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import logging
import urllib.parse
from collections.abc import Generator, Sequence
from contextlib import contextmanager, suppress
from pathlib import Path
from time import sleep
from typing import Any, Dict, Generator, Optional, Sequence, Tuple, TypeVar
from typing import Any, Optional, TypeVar

import attrs
import packaging.specifiers as specifiers
Expand Down Expand Up @@ -95,7 +96,7 @@ def __init__(
if check_server_version:
self.check_server_version()

self._repos: Dict[str, Repo] = {}
self._repos: dict[str, Repo] = {}
"""A cache for created Repository instances"""

_ORG_SLUG_HEADER = "X-Organization"
Expand Down Expand Up @@ -183,7 +184,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
def close(self) -> None:
return self.__exit__(None, None, None)

def login(self, credentials: Tuple[str, str]) -> None:
def login(self, credentials: tuple[str, str]) -> None:
(auth, _) = self.api_client.auth_api.create_login(
models.LoginSerializerExRequest(username=credentials[0], password=credentials[1])
)
Expand Down Expand Up @@ -211,7 +212,7 @@ def wait_for_completion(
rq_id: str,
*,
status_check_period: Optional[int] = None,
) -> Tuple[models.Request, urllib3.HTTPResponse]:
) -> tuple[models.Request, urllib3.HTTPResponse]:
if status_check_period is None:
status_check_period = self.config.status_check_period

Expand Down Expand Up @@ -319,8 +320,8 @@ def make_endpoint_url(
path: str,
*,
psub: Optional[Sequence[Any]] = None,
kwsub: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
kwsub: Optional[dict[str, Any]] = None,
query_params: Optional[dict[str, Any]] = None,
) -> str:
url = self.host + path
if psub or kwsub:
Expand All @@ -331,7 +332,7 @@ def make_endpoint_url(


def make_client(
host: str, *, port: Optional[int] = None, credentials: Optional[Tuple[str, str]] = None
host: str, *, port: Optional[int] = None, credentials: Optional[tuple[str, str]] = None
) -> Client:
url = host.rstrip("/")
if port:
Expand Down
10 changes: 5 additions & 5 deletions cvat-sdk/cvat_sdk/core/downloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
from contextlib import closing
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional

from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.helpers import expect_status
Expand Down Expand Up @@ -80,8 +80,8 @@ def prepare_file(
self,
endpoint: Endpoint,
*,
url_params: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
url_params: Optional[dict[str, Any]] = None,
query_params: Optional[dict[str, Any]] = None,
status_check_period: Optional[int] = None,
):
client = self._client
Expand Down Expand Up @@ -118,8 +118,8 @@ def prepare_and_download_file_from_endpoint(
endpoint: Endpoint,
filename: Path,
*,
url_params: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
url_params: Optional[dict[str, Any]] = None,
query_params: Optional[dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
Expand Down
5 changes: 3 additions & 2 deletions cvat-sdk/cvat_sdk/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import io
import json
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union
from collections.abc import Iterable
from typing import Any, Optional, Union

import tqdm
import urllib3
Expand All @@ -19,7 +20,7 @@

def get_paginated_collection(
endpoint: Endpoint, *, return_json: bool = False, **kwargs
) -> Union[List, List[Dict[str, Any]]]:
) -> Union[list, list[dict[str, Any]]]:
"""
Accumulates results from all the pages
"""
Expand Down
3 changes: 2 additions & 1 deletion cvat-sdk/cvat_sdk/core/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from __future__ import annotations

import contextlib
from typing import Generator, Iterable, Optional, TypeVar
from collections.abc import Generator, Iterable
from typing import Optional, TypeVar

T = TypeVar("T")

Expand Down
3 changes: 2 additions & 1 deletion cvat-sdk/cvat_sdk/core/proxies/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
# SPDX-License-Identifier: MIT

from abc import ABC
from collections.abc import Sequence
from enum import Enum
from typing import Optional, Sequence
from typing import Optional

from cvat_sdk import models
from cvat_sdk.core.proxies.model_proxy import _EntityT
Expand Down
4 changes: 1 addition & 3 deletions cvat-sdk/cvat_sdk/core/proxies/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from __future__ import annotations

from typing import List

from cvat_sdk.api_client import apis, models
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.proxies.model_proxy import (
Expand Down Expand Up @@ -53,7 +51,7 @@ class Issue(
):
_model_partial_update_arg = "patched_issue_write_request"

def get_comments(self) -> List[Comment]:
def get_comments(self) -> list[Comment]:
return [
Comment(self._client, m)
for m in get_paginated_collection(
Expand Down
11 changes: 6 additions & 5 deletions cvat-sdk/cvat_sdk/core/proxies/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import io
import mimetypes
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import TYPE_CHECKING, Optional

from PIL import Image

Expand Down Expand Up @@ -93,7 +94,7 @@ def download_frames(
outdir: StrPath = ".",
quality: str = "original",
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
) -> Optional[list[Image.Image]]:
"""
Download the requested frame numbers for a job and save images as outdir/filename_pattern
"""
Expand Down Expand Up @@ -125,12 +126,12 @@ def get_meta(self) -> models.IDataMetaRead:
(meta, _) = self.api.retrieve_data_meta(self.id)
return meta

def get_labels(self) -> List[models.ILabel]:
def get_labels(self) -> list[models.ILabel]:
return get_paginated_collection(
self._client.api_client.labels_api.list_endpoint, job_id=self.id
)

def get_frames_info(self) -> List[models.IFrameMeta]:
def get_frames_info(self) -> list[models.IFrameMeta]:
return self.get_meta().frames

def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
Expand All @@ -141,7 +142,7 @@ def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
),
)

def get_issues(self) -> List[Issue]:
def get_issues(self) -> list[Issue]:
return [
Issue(self._client, m)
for m in get_paginated_collection(
Expand Down
Loading

0 comments on commit 8826c0f

Please sign in to comment.