Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Another round of mypy fixes #415

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions kr8s/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Api(object):
"""

_asyncio = True
_instances = {}
_instances: Dict[str, weakref.WeakValueDictionary] = {}

def __init__(self, **kwargs) -> None:
if not kwargs.pop("bypass_factory", False):
Expand All @@ -61,7 +61,7 @@ def __init__(self, **kwargs) -> None:
self._url = kwargs.get("url")
self._kubeconfig = kwargs.get("kubeconfig")
self._serviceaccount = kwargs.get("serviceaccount")
self._session = None
self._session: Optional[httpx.AsyncClient] = None
self._timeout = None
self.auth = KubeAuth(
url=self._url,
Expand Down Expand Up @@ -159,11 +159,13 @@ async def call_api(
while True:
try:
if stream:
assert self._session
async with self._session.stream(**kwargs) as response:
if raise_for_status:
response.raise_for_status()
yield response
else:
assert self._session
response = await self._session.request(**kwargs)
if raise_for_status:
response.raise_for_status()
Expand Down Expand Up @@ -226,9 +228,9 @@ async def open_websocket(
client=self._session, **kwargs
) as response:
yield response
except httpx_ws.HTTPXWSException as e:
except httpx_ws.WebSocketDisconnect as e:
if e.code and e.code != 1000:
if e.status in (401, 403) and auth_attempts < 3:
if e.code in (401, 403) and auth_attempts < 3:
auth_attempts += 1
await self.auth.reauthenticate()
continue
Expand Down Expand Up @@ -289,14 +291,14 @@ async def async_whoami(self):
@contextlib.asynccontextmanager
async def async_get_kind(
self,
kind: Union[str, type],
kind: Union[str, Type[APIObject]],
namespace: Optional[str] = None,
label_selector: Optional[Union[str, Dict]] = None,
field_selector: Optional[Union[str, Dict]] = None,
params: Optional[dict] = None,
watch: bool = False,
**kwargs,
) -> AsyncGenerator[Tuple[Type[APIObject], dict], None]:
) -> AsyncGenerator[Tuple[Type[APIObject], httpx.Response], None]:
"""Get a Kubernetes resource."""
from ._objects import get_class

Expand Down Expand Up @@ -328,7 +330,8 @@ async def async_get_kind(
break
except ServerError as e:
warnings.warn(str(e))
obj_cls = get_class(kind, _asyncio=self._asyncio)
if isinstance(kind, str):
obj_cls = get_class(kind, _asyncio=self._asyncio)
params = params or None
async with self.call_api(
method="GET",
Expand Down
10 changes: 5 additions & 5 deletions kr8s/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(
resource: APIObject,
command: List[str],
container: Optional[str] = None,
stdin: Optional[Union[str | BinaryIO]] = None,
stdout: Optional[Union[str | BinaryIO]] = None,
stderr: Optional[Union[str | BinaryIO]] = None,
stdin: Optional[Union[str, BinaryIO]] = None,
stdout: Optional[BinaryIO] = None,
stderr: Optional[BinaryIO] = None,
check: bool = True,
capture_output: bool = True,
) -> None:
Expand All @@ -46,7 +46,7 @@ def __init__(
self.args = command
self.stdout = b""
self.stderr = b""
self.returncode = None
self.returncode: int
self.check = check

@asynccontextmanager
Expand Down Expand Up @@ -114,7 +114,7 @@ async def run(
yield self

async def wait(self) -> CompletedExec:
return self.returncode
return self.as_completed()

def as_completed(self) -> CompletedExec:
return CompletedExec(
Expand Down
32 changes: 17 additions & 15 deletions kr8s/_portforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import socket
import sys
from contextlib import asynccontextmanager, suppress
from typing import TYPE_CHECKING, BinaryIO, List, Optional
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional

import anyio
import httpx_ws
Expand All @@ -22,7 +22,7 @@
if sys.version_info < (3, 12, 1):
# contextlib.supress() in Python 3.12.1 supprts ExceptionGroups
# For older versions, we use the exceptiongroup backport
from exceptiongroup import suppress # noqa: F811
from exceptiongroup import suppress # type: ignore # noqa: F811


class PortForward:
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
"see https://github.com/kr8s-org/kr8s/issues/104"
)
self.server = None
self.servers = []
self.servers: List[asyncio.Server] = []
self.remote_port = remote_port
self.local_port = local_port if local_port is not None else 0
if isinstance(address, str):
Expand All @@ -102,10 +102,10 @@ def __init__(
self._resource = resource
self.pod = None
self._loop = asyncio.get_event_loop()
self._tasks = []
self._tasks: List[asyncio.Task] = []
self._run_task = None
self._bg_future = None
self._bg_task = None
self._bg_future: Optional[asyncio.Future] = None
self._bg_task: Optional[asyncio.Task] = None

async def __aenter__(self, *args, **kwargs):
self._run_task = self._run()
Expand All @@ -117,7 +117,7 @@ async def __aexit__(self, *args, **kwargs):
async def start(self) -> int:
"""Start a background task with the port forward running."""
if self._bg_task is not None:
return
return self.local_port

async def f():
self._bg_future = self._loop.create_future()
Expand All @@ -132,7 +132,8 @@ async def f():

async def stop(self) -> None:
"""Stop the background task."""
self._bg_future.set_result(None)
if self._bg_future:
self._bg_future.set_result(None)
self._bg_task = None

async def run_forever(self) -> None:
Expand All @@ -150,7 +151,7 @@ async def run_forever(self) -> None:
await server.serve_forever()

@asynccontextmanager
async def _run(self) -> int:
async def _run(self) -> AsyncGenerator[int, None]:
"""Start the port forward for multiple bind addresses and yield the local port."""
if self.local_port == 0:
self.local_port = self._find_available_port()
Expand All @@ -173,7 +174,7 @@ async def _run(self) -> int:
await server.wait_closed()
self.servers.remove(server)

async def _select_pod(self) -> object:
async def _select_pod(self) -> APIObject:
"""Select a Pod to forward to."""
from ._objects import Pod

Expand All @@ -184,10 +185,11 @@ async def _select_pod(self) -> object:
try:
return random.choice(await self._resource.async_ready_pods())
except IndexError:
raise RuntimeError("No ready pods found")
pass
raise RuntimeError("No ready pods found")

@asynccontextmanager
async def _connect_websocket(self) -> None:
async def _connect_websocket(self):
"""Connect to the Kubernetes portforward websocket."""
connection_attempts = 0
while True:
Expand All @@ -213,7 +215,7 @@ async def _connect_websocket(self) -> None:
raise ConnectionClosedError("Unable to connect to Pod") from e
await anyio.sleep(0.1 * connection_attempts)

async def _sync_sockets(self, reader: BinaryIO, writer: BinaryIO) -> None:
async def _sync_sockets(self, reader, writer) -> None:
"""Start two tasks to copy bytes from tcp=>websocket and websocket=>tcp."""
try:
async with self._connect_websocket() as ws:
Expand All @@ -224,7 +226,7 @@ async def _sync_sockets(self, reader: BinaryIO, writer: BinaryIO) -> None:
finally:
writer.close()

async def _tcp_to_ws(self, ws, reader: BinaryIO) -> None:
async def _tcp_to_ws(self, ws, reader) -> None:
while True:
data = await reader.read(1024 * 1024)
if not data:
Expand All @@ -237,7 +239,7 @@ async def _tcp_to_ws(self, ws, reader: BinaryIO) -> None:
except ConnectionResetError:
raise ConnectionClosedError("Websocket closed")

async def _ws_to_tcp(self, ws, writer: BinaryIO) -> None:
async def _ws_to_tcp(self, ws, writer) -> None:
channels = []
while True:
message = await ws.receive_bytes()
Expand Down
Loading