From c5918c2975756b8df197332d7d2b0f9c11ac3d0f Mon Sep 17 00:00:00 2001 From: PyryM Date: Fri, 8 Mar 2024 16:50:19 -0500 Subject: [PATCH] fix binding generation of some wgpu-native extras (#24) webgpu uses the convention for arrays/lists of: size_t thingCount const thing* things wgpu uses the convention (note reversed order): const thing* things size_t thingCount --- codegen/generate.ts | 52 +++++++++++++++++++++++++++---- pyproject.toml | 2 +- xgpu/bindings.py | 76 ++++++++++++++++++++++++++------------------- 3 files changed, 91 insertions(+), 39 deletions(-) diff --git a/codegen/generate.ts b/codegen/generate.ts index 3c7da7a..ded735c 100644 --- a/codegen/generate.ts +++ b/codegen/generate.ts @@ -314,6 +314,47 @@ interface Emittable { emitCDef?(api: ApiInfo): string; } +interface FieldPair { + name: string; + type: Refinfo; +} + +function isPluralOf(query: string, thing: string): boolean { + if(query === thing + "s") { + return true + } + if(thing.endsWith("y") && query === `${thing.slice(0, thing.length-1)}ies`) { + return true + } + return false +} + +function areListField(f0: FieldPair, f1: FieldPair): [FieldPair, FieldPair] | undefined { + const isCount = (f: FieldPair) => f.name.endsWith("Count") && f.type.inner === "size_t"; + + if(isCount(f0)) { + // already in (count, field) order + } else if(isCount(f1)) { + // idiosyncratic (field, count) order + [f1, f0] = [f0, f1]; + } else { + return undefined + } + + const m = f0.name.match(/^(.*)Count$/); + if(!m) { + return undefined + } + const [_wholeMatch, prefix] = m; + + // double check that second field is correctly named + if(!isPluralOf(f1.name, prefix)) { + return undefined + } + + return [f0, f1] +} + class ApiInfo { types: Map = new Map(); wrappers: Map = new Map(); @@ -470,14 +511,13 @@ class ApiInfo { while (fieldPos < rawFields.length) { const { name, type } = rawFields[fieldPos]; if ( - name.endsWith("Count") && - type.inner === "size_t" && - fieldPos + 1 < rawFields.length + fieldPos + 1 < rawFields.length && + areListField(rawFields[fieldPos], rawFields[fieldPos+1]) ) { - const { name: arrName, type: arrType } = rawFields[fieldPos + 1]; - const innerCType = this.getType(arrType); + const [countField, listField] = areListField(rawFields[fieldPos], rawFields[fieldPos+1])!; + const innerCType = this.getType(listField.type); this.getListWrapper(innerCType); - fields.push(new ArrayField(arrName, name, innerCType)); + fields.push(new ArrayField(listField.name, countField.name, innerCType)); fieldPos += 2; } else if ( name.endsWith("Callback") && diff --git a/pyproject.toml b/pyproject.toml index e59f8f8..3b2371e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "xgpu" -version = "0.8.0" +version = "0.8.1" readme = "README.md" requires-python = ">=3.7" dependencies = ["cffi"] diff --git a/xgpu/bindings.py b/xgpu/bindings.py index 3854677..4ceb4bd 100644 --- a/xgpu/bindings.py +++ b/xgpu/bindings.py @@ -8385,13 +8385,18 @@ def __init__(self, *, cdata: Optional[CData] = None, parent: Optional[Any] = Non self._cdata.chain.sType = SType.BindGroupEntryExtras @property - def buffers(self) -> "Buffer": - return Buffer(self._cdata.buffers, add_ref=True) + def buffers(self) -> "BufferList": + return self._buffers @buffers.setter - def buffers(self, v: "Buffer") -> None: - self._buffers = v - self._cdata.buffers = v._cdata + def buffers(self, v: Union["BufferList", List["Buffer"]]) -> None: + if isinstance(v, list): + v2 = BufferList(v) + else: + v2 = v + self._buffers = v2 + self._cdata.bufferCount = v2._count + self._cdata.buffers = v2._ptr @property def samplers(self) -> "SamplerList": @@ -8404,7 +8409,7 @@ def samplers(self, v: Union["SamplerList", List["Sampler"]]) -> None: else: v2 = v self._samplers = v2 - self._cdata.bufferCount = v2._count + self._cdata.samplerCount = v2._count self._cdata.samplers = v2._ptr @property @@ -8418,17 +8423,9 @@ def textureViews(self, v: Union["TextureViewList", List["TextureView"]]) -> None else: v2 = v self._textureViews = v2 - self._cdata.samplerCount = v2._count + self._cdata.textureViewCount = v2._count self._cdata.textureViews = v2._ptr - @property - def textureViewCount(self) -> int: - return self._cdata.textureViewCount - - @textureViewCount.setter - def textureViewCount(self, v: int) -> None: - self._cdata.textureViewCount = v - @property def _chain(self) -> Any: return self._cdata.chain @@ -8436,16 +8433,14 @@ def _chain(self) -> Any: def bindGroupEntryExtras( *, - buffers: "Buffer", + buffers: Union["BufferList", List["Buffer"]], samplers: Union["SamplerList", List["Sampler"]], textureViews: Union["TextureViewList", List["TextureView"]], - textureViewCount: int, ) -> BindGroupEntryExtras: ret = BindGroupEntryExtras(cdata=None, parent=None) ret.buffers = buffers ret.samplers = samplers ret.textureViews = textureViews - ret.textureViewCount = textureViewCount return ret @@ -8481,21 +8476,20 @@ def __init__(self, *, cdata: Optional[CData] = None, parent: Optional[Any] = Non self._cdata.chain.sType = SType.QuerySetDescriptorExtras @property - def pipelineStatistics(self) -> "PipelineStatisticName": + def pipelineStatistics(self) -> "PipelineStatisticNameList": return self._pipelineStatistics @pipelineStatistics.setter - def pipelineStatistics(self, v: "PipelineStatisticName") -> None: - self._pipelineStatistics = v - self._cdata.pipelineStatistics = int(v) - - @property - def pipelineStatisticCount(self) -> int: - return self._cdata.pipelineStatisticCount - - @pipelineStatisticCount.setter - def pipelineStatisticCount(self, v: int) -> None: - self._cdata.pipelineStatisticCount = v + def pipelineStatistics( + self, v: Union["PipelineStatisticNameList", List["PipelineStatisticName"]] + ) -> None: + if isinstance(v, list): + v2 = PipelineStatisticNameList(v) + else: + v2 = v + self._pipelineStatistics = v2 + self._cdata.pipelineStatisticCount = v2._count + self._cdata.pipelineStatistics = v2._ptr @property def _chain(self) -> Any: @@ -8503,11 +8497,11 @@ def _chain(self) -> Any: def querySetDescriptorExtras( - *, pipelineStatistics: "PipelineStatisticName", pipelineStatisticCount: int + *, + pipelineStatistics: Union["PipelineStatisticNameList", List["PipelineStatisticName"]], ) -> QuerySetDescriptorExtras: ret = QuerySetDescriptorExtras(cdata=None, parent=None) ret.pipelineStatistics = pipelineStatistics - ret.pipelineStatisticCount = pipelineStatisticCount return ret @@ -8974,6 +8968,15 @@ def __init__(self, items: List["PushConstantRange"], count: int = 0): self._ptr[idx] = _ffi_deref(item._cdata) +class BufferList: + def __init__(self, items: List["Buffer"], count: int = 0): + self._stashed = items + self._count = max(len(items), count) + self._ptr = _ffi_new("WGPUBuffer[]", self._count) + for idx, item in enumerate(items): + self._ptr[idx] = item._cdata + + class SamplerList: def __init__(self, items: List["Sampler"], count: int = 0): self._stashed = items @@ -8992,6 +8995,15 @@ def __init__(self, items: List["TextureView"], count: int = 0): self._ptr[idx] = item._cdata +class PipelineStatisticNameList: + def __init__(self, items: List["PipelineStatisticName"], count: int = 0): + self._stashed = items + self._count = max(len(items), count) + self._ptr = _ffi_new("WGPUPipelineStatisticName[]", self._count) + for idx, item in enumerate(items): + self._ptr[idx] = int(item) + + class IntList: def __init__(self, items: List[int], count: int = 0): self._stashed = items