Skip to content

Commit

Permalink
Fix return type annotations for functions used with @contextmanager (#…
Browse files Browse the repository at this point in the history
…8617)

Some of them are annotated with an `Iterator` return type. However...

It just occurred to me that `@contextmanager` cannot work with a
function that returns a plain iterator, since it relies on the generator
class's `throw` method. `contextmanager` is defined in typeshed as
accepting an iterator-returning function, but that appears to be a bug:
<python/typeshed#2772>.

Change all such annotations to a `Generator` type instead.

Some annotations are also broken in other ways; fix them too.
  • Loading branch information
SpecLad authored Oct 30, 2024
1 parent 4dd7f22 commit 23e1e0a
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cvat-sdk/cvat_sdk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from contextlib import contextmanager, suppress
from pathlib import Path
from time import sleep
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, TypeVar
from typing import Any, Dict, Generator, Optional, Sequence, Tuple, TypeVar

import attrs
import packaging.specifiers as specifiers
Expand Down Expand Up @@ -121,7 +121,7 @@ def organization_slug(self, org_slug: Optional[str]):
self.api_client.default_headers[self._ORG_SLUG_HEADER] = org_slug

@contextmanager
def organization_context(self, slug: str) -> Iterator[None]:
def organization_context(self, slug: str) -> Generator[None, None, None]:
prev_slug = self.organization_slug
self.organization_slug = slug
try:
Expand Down
4 changes: 2 additions & 2 deletions cvat-sdk/cvat_sdk/core/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import contextlib
from typing import ContextManager, Iterable, Optional, TypeVar
from typing import Generator, Iterable, Optional, TypeVar

T = TypeVar("T")

Expand All @@ -26,7 +26,7 @@ class ProgressReporter:
"""

@contextlib.contextmanager
def task(self, **kwargs) -> ContextManager[None]:
def task(self, **kwargs) -> Generator[None, None, None]:
"""
Returns a context manager that represents a long-running task
for which progress can be reported.
Expand Down
4 changes: 2 additions & 2 deletions cvat-sdk/cvat_sdk/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
BinaryIO,
ContextManager,
Dict,
Iterator,
Generator,
Literal,
Sequence,
TextIO,
Expand Down Expand Up @@ -43,7 +43,7 @@ def atomic_writer(
@contextlib.contextmanager
def atomic_writer(
path: Union[os.PathLike, str], mode: Literal["w", "wb"], encoding: str = "UTF-8"
) -> Iterator[IO]:
) -> Generator[IO, None, None]:
"""
Returns a context manager that, when entered, returns a handle to a temporary
file opened with the specified `mode` and `encoding`. If the context manager
Expand Down
4 changes: 3 additions & 1 deletion cvat/apps/engine/media_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ def extract(self):

class _AvVideoReading:
@contextmanager
def read_av_container(self, source: Union[str, io.BytesIO]) -> av.container.InputContainer:
def read_av_container(
self, source: Union[str, io.BytesIO]
) -> Generator[av.container.InputContainer, None, None]:
if isinstance(source, io.BytesIO):
source.seek(0) # required for re-reading

Expand Down
4 changes: 2 additions & 2 deletions tests/python/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import threading
import unittest
from pathlib import Path
from typing import Any, Dict, Iterator, List, Union
from typing import Any, Dict, Generator, List, Union

import requests

Expand Down Expand Up @@ -39,7 +39,7 @@ def generate_images(dst_dir: Path, count: int) -> List[Path]:


@contextlib.contextmanager
def https_reverse_proxy() -> Iterator[str]:
def https_reverse_proxy() -> Generator[str, None, None]:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
cert_dir = Path(__file__).parent
Expand Down

0 comments on commit 23e1e0a

Please sign in to comment.