Skip to content

Commit

Permalink
Fix an inheritance error
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Sep 25, 2023
1 parent 772156f commit 7c7889c
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownParameterType = "none"
reportMissingTypeArgument = "none"
reportUnnecessaryIsInstance = "none"
reportUnnecessaryIsInstance = "warning"
reportPrivateImportUsage = "none"
reportPrivateUsage = "none"
reportUnnecessaryTypeIgnoreComment = "warning"
Expand Down
2 changes: 1 addition & 1 deletion ranzen/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ def enum_name_str(enum_class: type[E]) -> type[E]:
def __str__(self: Enum) -> str:
return self.name.lower()

enum_class.__str__ = __str__ # type: ignore
enum_class.__str__ = __str__
return enum_class
27 changes: 6 additions & 21 deletions ranzen/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,23 +200,14 @@ class AddDict(Dict[_KT, _VT], Addable):
"""

@overload
def __add__(
self: Self,
other: int,
) -> Self:
def __add__(self, other: int) -> Self:
...

@overload
def __add__(
self: Self,
other: dict[_KT, _VT2],
) -> AddDict[_KT, _VT | _VT2]:
def __add__(self, other: dict[_KT, _VT2]) -> AddDict[_KT, _VT | _VT2]:
...

def __add__(
self: Self,
other: int | dict[_KT, _VT2],
) -> Self | AddDict[_KT, _VT | _VT2]:
def __add__(self, other: int | dict[_KT, _VT2]) -> Self | AddDict[_KT, _VT | _VT2]:
# Allow ``other`` to be an integer, but specifying the identity function, for compatibility
# with th 'no-default' version of``sum``.
if isinstance(other, int):
Expand Down Expand Up @@ -245,20 +236,14 @@ def __add__(
return copy

@overload
def __radd__(
self: Self,
other: int,
) -> Self:
def __radd__(self, other: int) -> Self:
...

@overload
def __radd__(
self: Self,
other: dict[_KT, _VT2],
) -> AddDict[_KT, _VT | _VT2]:
def __radd__(self, other: dict[_KT, _VT2]) -> AddDict[_KT, _VT | _VT2]:
...

def __radd__(self: Self, other: int | dict[_KT, _VT2]) -> Self | AddDict[_KT, _VT | _VT2]:
def __radd__(self, other: int | dict[_KT, _VT2]) -> Self | AddDict[_KT, _VT | _VT2]:
return self + other


Expand Down
6 changes: 4 additions & 2 deletions ranzen/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def inf_generator(iterable: Iterable[T]) -> Iterator[T]:
"""Get DataLoaders in a single infinite loop.
for i, (x, y) in enumerate(inf_generator(train_loader))
:yield: elements from the given iterable
"""
iterator = iter(iterable)
# try to take one element to ensure that the iterator is not empty
Expand Down Expand Up @@ -84,8 +86,8 @@ def __enter__(self) -> Event:
Mimics torch.cuda.Event.
"""
if self._cuda:
self._event_start = torch.cuda.Event(enable_timing=True) # type: ignore
self._event_start.record() # type: ignore
self._event_start = torch.cuda.Event(enable_timing=True) # pyright: ignore
self._event_start.record() # pyright: ignore
else:
self._event_start = datetime.now()
return self
Expand Down
17 changes: 12 additions & 5 deletions tests/relay_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Protocol
Expand All @@ -12,17 +13,20 @@

class DummyOption(Protocol):
name: str
value: int | str

@property
def value(self) -> int | str:
...


@dataclass
class DummyOptionA(DummyOption):
class DummyOptionA:
name: str = "a"
value: int = 7


@dataclass
class DummyOptionB(DummyOption):
class DummyOptionB:
name: str = "b"
value: str = "5"

Expand Down Expand Up @@ -61,11 +65,14 @@ def run(self, raw_config: Optional[Dict[str, Any]] = None) -> None:
def test_relay(tmpdir: Path, clear_cache: bool, instantiate_recursively: bool) -> None:
args = ["", "attr1=foo", "attr2=bar"]
with patch("sys.argv", args):
ops1 = [
ops1: Sequence[type[DummyOption] | Option[DummyOption]] = [
Option(DummyOptionA, "foo"),
DummyOptionB,
]
ops2 = [DummyOptionA, Option(DummyOptionB, "bar")]
ops2: Sequence[type[DummyOption] | Option[DummyOption]] = [
DummyOptionA,
Option(DummyOptionB, "bar"),
]
options = {"attr1": ops1, "attr2": ops2}
for _ in range(2):
DummyRelay.with_hydra(
Expand Down

0 comments on commit 7c7889c

Please sign in to comment.