Skip to content

Commit

Permalink
check status of adapter before returning (#10)
Browse files Browse the repository at this point in the history
Throw errors if fail to get adapter/device
  • Loading branch information
mikedh authored Jan 26, 2024
1 parent 695352f commit 4fb579b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "xgpu"
version = "0.5.0"
version = "0.5.1"
readme = "README.md"
requires-python = ">=3.7"
dependencies = ["cffi"]
Expand Down Expand Up @@ -70,4 +70,4 @@ before-build = "pip install cffi && cd xgpu && python _build_ext.py"
skip = ["*-win32", "*-manylinux_i686", "*musl*"]

# just make sure the library imports for now
test-command = 'python -c "import xgpu"'
test-command = 'python -c "import xgpu; import xgpu.conveniences"'
61 changes: 48 additions & 13 deletions xgpu/conveniences.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Optional
from typing import Optional, Tuple, List

from . import bindings as xg
from .extensions import XAdapter, XDevice
Expand Down Expand Up @@ -27,12 +27,18 @@ def get_adapter(
instance: Optional[xg.Instance] = None,
power=xg.PowerPreference.HighPerformance,
surface: Optional[xg.Surface] = None,
) -> tuple[XAdapter, xg.Instance]:
adapter: list[Optional[xg.Adapter]] = [None]
timeout: float = 60.0,
) -> Tuple[XAdapter, xg.Instance]:
"""
Get an adapter, blocking up to `timeout` seconds
"""

def adapterCB(status: xg.RequestAdapterStatus, gotten: xg.Adapter, msg: str):
# will be populated by a callback
stash: List[Optional[Tuple[xg.RequestAdapterStatus, xg.Adapter, str]]] = [None]

def adapterCB(status: xg.RequestAdapterStatus, adapter: xg.Adapter, msg: str):
print("Got adapter with msg:", msg, ", status:", status.name)
adapter[0] = gotten
stash[0] = (status, adapter, msg)

cb = xg.RequestAdapterCallback(adapterCB)

Expand All @@ -48,22 +54,40 @@ def adapterCB(status: xg.RequestAdapterStatus, gotten: xg.Adapter, msg: str):
cb,
)

while adapter[0] is None:
deadline = time.time() + timeout
while stash[0] is None:
time.sleep(0.1)
if time.time() > deadline:
raise TimeoutError(f"Timed out getting adapter after {timeout:0.2f}s!")

# we have exited the loop without raising
status, adapter, msg = stash[0]

return (XAdapter(adapter[0]), instance)
if status != xg.RequestAdapterStatus.Success:
raise RuntimeError(
f"Failed to get adapter, status=`{status.name}`, message:'{msg}'"
)

return XAdapter(adapter), instance


def get_device(
adapter: xg.Adapter,
features: Optional[list[xg.FeatureName]] = None,
features: Optional[List[xg.FeatureName]] = None,
limits: Optional[xg.RequiredLimits] = None,
timeout: float = 60,
) -> XDevice:
device: list[Optional[xg.Device]] = [None]
"""
Get a device, blocking up to `timeout` seconds.
"""

# collect the device from a callback
stash: List[Optional[Tuple[xg.RequestDeviceStatus, xg.Device, str]]] = [None]

def deviceCB(status: xg.RequestDeviceStatus, gotten: xg.Device, msg: str):
def deviceCB(status: xg.RequestDeviceStatus, device: xg.Device, msg: str):
print("Got device with msg:", msg, ", status:", status.name)
device[0] = gotten

stash[0] = (status, device, msg)

def deviceLostCB(reason: xg.DeviceLostReason, msg: str):
print("Lost device!:", reason, msg)
Expand All @@ -90,7 +114,18 @@ def deviceLostCB(reason: xg.DeviceLostReason, msg: str):
cb,
)

while device[0] is None:
deadline = time.time() + timeout
while stash[0] is None:
time.sleep(0.1)
if time.time() > deadline:
raise TimeoutError(f"Timed out getting device after {timeout:0.2f}s!")

# we have exited the loop without raising
status, device, msg = stash[0]

if status != xg.RequestDeviceStatus.Success:
raise RuntimeError(
f"Failed to get device, status=`{status.name}`, message:'{msg}'"
)

return XDevice(device[0])
return XDevice(device)

0 comments on commit 4fb579b

Please sign in to comment.