diff --git a/pyproject.toml b/pyproject.toml index 50e4c4d..42e63fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "xgpu" -version = "0.5.0" +version = "0.5.1" readme = "README.md" requires-python = ">=3.7" dependencies = ["cffi"] @@ -70,4 +70,4 @@ before-build = "pip install cffi && cd xgpu && python _build_ext.py" skip = ["*-win32", "*-manylinux_i686", "*musl*"] # just make sure the library imports for now -test-command = 'python -c "import xgpu"' +test-command = 'python -c "import xgpu; import xgpu.conveniences"' diff --git a/xgpu/conveniences.py b/xgpu/conveniences.py index 22155c4..987c6b9 100644 --- a/xgpu/conveniences.py +++ b/xgpu/conveniences.py @@ -1,5 +1,5 @@ import time -from typing import Optional +from typing import Optional, Tuple, List from . import bindings as xg from .extensions import XAdapter, XDevice @@ -27,12 +27,18 @@ def get_adapter( instance: Optional[xg.Instance] = None, power=xg.PowerPreference.HighPerformance, surface: Optional[xg.Surface] = None, -) -> tuple[XAdapter, xg.Instance]: - adapter: list[Optional[xg.Adapter]] = [None] + timeout: float = 60.0, +) -> Tuple[XAdapter, xg.Instance]: + """ + Get an adapter, blocking up to `timeout` seconds + """ - def adapterCB(status: xg.RequestAdapterStatus, gotten: xg.Adapter, msg: str): + # will be populated by a callback + stash: List[Optional[Tuple[xg.RequestAdapterStatus, xg.Adapter, str]]] = [None] + + def adapterCB(status: xg.RequestAdapterStatus, adapter: xg.Adapter, msg: str): print("Got adapter with msg:", msg, ", status:", status.name) - adapter[0] = gotten + stash[0] = (status, adapter, msg) cb = xg.RequestAdapterCallback(adapterCB) @@ -48,22 +54,40 @@ def adapterCB(status: xg.RequestAdapterStatus, gotten: xg.Adapter, msg: str): cb, ) - while adapter[0] is None: + deadline = time.time() + timeout + while stash[0] is None: time.sleep(0.1) + if time.time() > deadline: + raise TimeoutError(f"Timed out getting adapter after {timeout:0.2f}s!") + + # we have exited the loop without raising + status, adapter, msg = stash[0] - return (XAdapter(adapter[0]), instance) + if status != xg.RequestAdapterStatus.Success: + raise RuntimeError( + f"Failed to get adapter, status=`{status.name}`, message:'{msg}'" + ) + + return XAdapter(adapter), instance def get_device( adapter: xg.Adapter, - features: Optional[list[xg.FeatureName]] = None, + features: Optional[List[xg.FeatureName]] = None, limits: Optional[xg.RequiredLimits] = None, + timeout: float = 60, ) -> XDevice: - device: list[Optional[xg.Device]] = [None] + """ + Get a device, blocking up to `timeout` seconds. + """ + + # collect the device from a callback + stash: List[Optional[Tuple[xg.RequestDeviceStatus, xg.Device, str]]] = [None] - def deviceCB(status: xg.RequestDeviceStatus, gotten: xg.Device, msg: str): + def deviceCB(status: xg.RequestDeviceStatus, device: xg.Device, msg: str): print("Got device with msg:", msg, ", status:", status.name) - device[0] = gotten + + stash[0] = (status, device, msg) def deviceLostCB(reason: xg.DeviceLostReason, msg: str): print("Lost device!:", reason, msg) @@ -90,7 +114,18 @@ def deviceLostCB(reason: xg.DeviceLostReason, msg: str): cb, ) - while device[0] is None: + deadline = time.time() + timeout + while stash[0] is None: time.sleep(0.1) + if time.time() > deadline: + raise TimeoutError(f"Timed out getting device after {timeout:0.2f}s!") + + # we have exited the loop without raising + status, device, msg = stash[0] + + if status != xg.RequestDeviceStatus.Success: + raise RuntimeError( + f"Failed to get device, status=`{status.name}`, message:'{msg}'" + ) - return XDevice(device[0]) + return XDevice(device)