Skip to content

Commit

Permalink
Another round of mypy fixes (#415)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson authored Jun 25, 2024
1 parent b56456f commit 853d3f5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 27 deletions.
17 changes: 10 additions & 7 deletions kr8s/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Api(object):
"""

_asyncio = True
_instances = {}
_instances: Dict[str, weakref.WeakValueDictionary] = {}

def __init__(self, **kwargs) -> None:
if not kwargs.pop("bypass_factory", False):
Expand All @@ -61,7 +61,7 @@ def __init__(self, **kwargs) -> None:
self._url = kwargs.get("url")
self._kubeconfig = kwargs.get("kubeconfig")
self._serviceaccount = kwargs.get("serviceaccount")
self._session = None
self._session: Optional[httpx.AsyncClient] = None
self._timeout = None
self.auth = KubeAuth(
url=self._url,
Expand Down Expand Up @@ -159,11 +159,13 @@ async def call_api(
while True:
try:
if stream:
assert self._session
async with self._session.stream(**kwargs) as response:
if raise_for_status:
response.raise_for_status()
yield response
else:
assert self._session
response = await self._session.request(**kwargs)
if raise_for_status:
response.raise_for_status()
Expand Down Expand Up @@ -226,9 +228,9 @@ async def open_websocket(
client=self._session, **kwargs
) as response:
yield response
except httpx_ws.HTTPXWSException as e:
except httpx_ws.WebSocketDisconnect as e:
if e.code and e.code != 1000:
if e.status in (401, 403) and auth_attempts < 3:
if e.code in (401, 403) and auth_attempts < 3:
auth_attempts += 1
await self.auth.reauthenticate()
continue
Expand Down Expand Up @@ -289,14 +291,14 @@ async def async_whoami(self):
@contextlib.asynccontextmanager
async def async_get_kind(
self,
kind: Union[str, type],
kind: Union[str, Type[APIObject]],
namespace: Optional[str] = None,
label_selector: Optional[Union[str, Dict]] = None,
field_selector: Optional[Union[str, Dict]] = None,
params: Optional[dict] = None,
watch: bool = False,
**kwargs,
) -> AsyncGenerator[Tuple[Type[APIObject], dict], None]:
) -> AsyncGenerator[Tuple[Type[APIObject], httpx.Response], None]:
"""Get a Kubernetes resource."""
from ._objects import get_class

Expand Down Expand Up @@ -328,7 +330,8 @@ async def async_get_kind(
break
except ServerError as e:
warnings.warn(str(e))
obj_cls = get_class(kind, _asyncio=self._asyncio)
if isinstance(kind, str):
obj_cls = get_class(kind, _asyncio=self._asyncio)
params = params or None
async with self.call_api(
method="GET",
Expand Down
10 changes: 5 additions & 5 deletions kr8s/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(
resource: APIObject,
command: List[str],
container: Optional[str] = None,
stdin: Optional[Union[str | BinaryIO]] = None,
stdout: Optional[Union[str | BinaryIO]] = None,
stderr: Optional[Union[str | BinaryIO]] = None,
stdin: Optional[Union[str, BinaryIO]] = None,
stdout: Optional[BinaryIO] = None,
stderr: Optional[BinaryIO] = None,
check: bool = True,
capture_output: bool = True,
) -> None:
Expand All @@ -46,7 +46,7 @@ def __init__(
self.args = command
self.stdout = b""
self.stderr = b""
self.returncode = None
self.returncode: int
self.check = check

@asynccontextmanager
Expand Down Expand Up @@ -114,7 +114,7 @@ async def run(
yield self

async def wait(self) -> CompletedExec:
return self.returncode
return self.as_completed()

def as_completed(self) -> CompletedExec:
return CompletedExec(
Expand Down
32 changes: 17 additions & 15 deletions kr8s/_portforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import socket
import sys
from contextlib import asynccontextmanager, suppress
from typing import TYPE_CHECKING, BinaryIO, List, Optional
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional

import anyio
import httpx_ws
Expand All @@ -22,7 +22,7 @@
if sys.version_info < (3, 12, 1):
# contextlib.supress() in Python 3.12.1 supprts ExceptionGroups
# For older versions, we use the exceptiongroup backport
from exceptiongroup import suppress # noqa: F811
from exceptiongroup import suppress # type: ignore # noqa: F811


class PortForward:
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
"see https://github.com/kr8s-org/kr8s/issues/104"
)
self.server = None
self.servers = []
self.servers: List[asyncio.Server] = []
self.remote_port = remote_port
self.local_port = local_port if local_port is not None else 0
if isinstance(address, str):
Expand All @@ -102,10 +102,10 @@ def __init__(
self._resource = resource
self.pod = None
self._loop = asyncio.get_event_loop()
self._tasks = []
self._tasks: List[asyncio.Task] = []
self._run_task = None
self._bg_future = None
self._bg_task = None
self._bg_future: Optional[asyncio.Future] = None
self._bg_task: Optional[asyncio.Task] = None

async def __aenter__(self, *args, **kwargs):
self._run_task = self._run()
Expand All @@ -117,7 +117,7 @@ async def __aexit__(self, *args, **kwargs):
async def start(self) -> int:
"""Start a background task with the port forward running."""
if self._bg_task is not None:
return
return self.local_port

async def f():
self._bg_future = self._loop.create_future()
Expand All @@ -132,7 +132,8 @@ async def f():

async def stop(self) -> None:
"""Stop the background task."""
self._bg_future.set_result(None)
if self._bg_future:
self._bg_future.set_result(None)
self._bg_task = None

async def run_forever(self) -> None:
Expand All @@ -150,7 +151,7 @@ async def run_forever(self) -> None:
await server.serve_forever()

@asynccontextmanager
async def _run(self) -> int:
async def _run(self) -> AsyncGenerator[int, None]:
"""Start the port forward for multiple bind addresses and yield the local port."""
if self.local_port == 0:
self.local_port = self._find_available_port()
Expand All @@ -173,7 +174,7 @@ async def _run(self) -> int:
await server.wait_closed()
self.servers.remove(server)

async def _select_pod(self) -> object:
async def _select_pod(self) -> APIObject:
"""Select a Pod to forward to."""
from ._objects import Pod

Expand All @@ -184,10 +185,11 @@ async def _select_pod(self) -> object:
try:
return random.choice(await self._resource.async_ready_pods())
except IndexError:
raise RuntimeError("No ready pods found")
pass
raise RuntimeError("No ready pods found")

@asynccontextmanager
async def _connect_websocket(self) -> None:
async def _connect_websocket(self):
"""Connect to the Kubernetes portforward websocket."""
connection_attempts = 0
while True:
Expand All @@ -213,7 +215,7 @@ async def _connect_websocket(self) -> None:
raise ConnectionClosedError("Unable to connect to Pod") from e
await anyio.sleep(0.1 * connection_attempts)

async def _sync_sockets(self, reader: BinaryIO, writer: BinaryIO) -> None:
async def _sync_sockets(self, reader, writer) -> None:
"""Start two tasks to copy bytes from tcp=>websocket and websocket=>tcp."""
try:
async with self._connect_websocket() as ws:
Expand All @@ -224,7 +226,7 @@ async def _sync_sockets(self, reader: BinaryIO, writer: BinaryIO) -> None:
finally:
writer.close()

async def _tcp_to_ws(self, ws, reader: BinaryIO) -> None:
async def _tcp_to_ws(self, ws, reader) -> None:
while True:
data = await reader.read(1024 * 1024)
if not data:
Expand All @@ -237,7 +239,7 @@ async def _tcp_to_ws(self, ws, reader: BinaryIO) -> None:
except ConnectionResetError:
raise ConnectionClosedError("Websocket closed")

async def _ws_to_tcp(self, ws, writer: BinaryIO) -> None:
async def _ws_to_tcp(self, ws, writer) -> None:
channels = []
while True:
message = await ws.receive_bytes()
Expand Down

0 comments on commit 853d3f5

Please sign in to comment.