Skip to content

Commit

Permalink
fix binding generation of some wgpu-native extras (#24)
Browse files Browse the repository at this point in the history
webgpu uses the convention for arrays/lists of:
size_t thingCount
const thing* things

wgpu uses the convention (note reversed order):
const thing* things
size_t thingCount
  • Loading branch information
PyryM authored Mar 8, 2024
1 parent 6da880c commit c5918c2
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 39 deletions.
52 changes: 46 additions & 6 deletions codegen/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,47 @@ interface Emittable {
emitCDef?(api: ApiInfo): string;
}

interface FieldPair {
name: string;
type: Refinfo;
}

function isPluralOf(query: string, thing: string): boolean {
if(query === thing + "s") {
return true
}
if(thing.endsWith("y") && query === `${thing.slice(0, thing.length-1)}ies`) {
return true
}
return false
}

function areListField(f0: FieldPair, f1: FieldPair): [FieldPair, FieldPair] | undefined {
const isCount = (f: FieldPair) => f.name.endsWith("Count") && f.type.inner === "size_t";

if(isCount(f0)) {
// already in (count, field) order
} else if(isCount(f1)) {
// idiosyncratic (field, count) order
[f1, f0] = [f0, f1];
} else {
return undefined
}

const m = f0.name.match(/^(.*)Count$/);
if(!m) {
return undefined
}
const [_wholeMatch, prefix] = m;

// double check that second field is correctly named
if(!isPluralOf(f1.name, prefix)) {
return undefined
}

return [f0, f1]
}

class ApiInfo {
types: Map<string, CType> = new Map();
wrappers: Map<string, Emittable> = new Map();
Expand Down Expand Up @@ -470,14 +511,13 @@ class ApiInfo {
while (fieldPos < rawFields.length) {
const { name, type } = rawFields[fieldPos];
if (
name.endsWith("Count") &&
type.inner === "size_t" &&
fieldPos + 1 < rawFields.length
fieldPos + 1 < rawFields.length &&
areListField(rawFields[fieldPos], rawFields[fieldPos+1])
) {
const { name: arrName, type: arrType } = rawFields[fieldPos + 1];
const innerCType = this.getType(arrType);
const [countField, listField] = areListField(rawFields[fieldPos], rawFields[fieldPos+1])!;
const innerCType = this.getType(listField.type);
this.getListWrapper(innerCType);
fields.push(new ArrayField(arrName, name, innerCType));
fields.push(new ArrayField(listField.name, countField.name, innerCType));
fieldPos += 2;
} else if (
name.endsWith("Callback") &&
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "xgpu"
version = "0.8.0"
version = "0.8.1"
readme = "README.md"
requires-python = ">=3.7"
dependencies = ["cffi"]
Expand Down
76 changes: 44 additions & 32 deletions xgpu/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8385,13 +8385,18 @@ def __init__(self, *, cdata: Optional[CData] = None, parent: Optional[Any] = Non
self._cdata.chain.sType = SType.BindGroupEntryExtras

@property
def buffers(self) -> "Buffer":
return Buffer(self._cdata.buffers, add_ref=True)
def buffers(self) -> "BufferList":
return self._buffers

@buffers.setter
def buffers(self, v: "Buffer") -> None:
self._buffers = v
self._cdata.buffers = v._cdata
def buffers(self, v: Union["BufferList", List["Buffer"]]) -> None:
if isinstance(v, list):
v2 = BufferList(v)
else:
v2 = v
self._buffers = v2
self._cdata.bufferCount = v2._count
self._cdata.buffers = v2._ptr

@property
def samplers(self) -> "SamplerList":
Expand All @@ -8404,7 +8409,7 @@ def samplers(self, v: Union["SamplerList", List["Sampler"]]) -> None:
else:
v2 = v
self._samplers = v2
self._cdata.bufferCount = v2._count
self._cdata.samplerCount = v2._count
self._cdata.samplers = v2._ptr

@property
Expand All @@ -8418,34 +8423,24 @@ def textureViews(self, v: Union["TextureViewList", List["TextureView"]]) -> None
else:
v2 = v
self._textureViews = v2
self._cdata.samplerCount = v2._count
self._cdata.textureViewCount = v2._count
self._cdata.textureViews = v2._ptr

@property
def textureViewCount(self) -> int:
return self._cdata.textureViewCount

@textureViewCount.setter
def textureViewCount(self, v: int) -> None:
self._cdata.textureViewCount = v

@property
def _chain(self) -> Any:
return self._cdata.chain


def bindGroupEntryExtras(
*,
buffers: "Buffer",
buffers: Union["BufferList", List["Buffer"]],
samplers: Union["SamplerList", List["Sampler"]],
textureViews: Union["TextureViewList", List["TextureView"]],
textureViewCount: int,
) -> BindGroupEntryExtras:
ret = BindGroupEntryExtras(cdata=None, parent=None)
ret.buffers = buffers
ret.samplers = samplers
ret.textureViews = textureViews
ret.textureViewCount = textureViewCount
return ret


Expand Down Expand Up @@ -8481,33 +8476,32 @@ def __init__(self, *, cdata: Optional[CData] = None, parent: Optional[Any] = Non
self._cdata.chain.sType = SType.QuerySetDescriptorExtras

@property
def pipelineStatistics(self) -> "PipelineStatisticName":
def pipelineStatistics(self) -> "PipelineStatisticNameList":
return self._pipelineStatistics

@pipelineStatistics.setter
def pipelineStatistics(self, v: "PipelineStatisticName") -> None:
self._pipelineStatistics = v
self._cdata.pipelineStatistics = int(v)

@property
def pipelineStatisticCount(self) -> int:
return self._cdata.pipelineStatisticCount

@pipelineStatisticCount.setter
def pipelineStatisticCount(self, v: int) -> None:
self._cdata.pipelineStatisticCount = v
def pipelineStatistics(
self, v: Union["PipelineStatisticNameList", List["PipelineStatisticName"]]
) -> None:
if isinstance(v, list):
v2 = PipelineStatisticNameList(v)
else:
v2 = v
self._pipelineStatistics = v2
self._cdata.pipelineStatisticCount = v2._count
self._cdata.pipelineStatistics = v2._ptr

@property
def _chain(self) -> Any:
return self._cdata.chain


def querySetDescriptorExtras(
*, pipelineStatistics: "PipelineStatisticName", pipelineStatisticCount: int
*,
pipelineStatistics: Union["PipelineStatisticNameList", List["PipelineStatisticName"]],
) -> QuerySetDescriptorExtras:
ret = QuerySetDescriptorExtras(cdata=None, parent=None)
ret.pipelineStatistics = pipelineStatistics
ret.pipelineStatisticCount = pipelineStatisticCount
return ret


Expand Down Expand Up @@ -8974,6 +8968,15 @@ def __init__(self, items: List["PushConstantRange"], count: int = 0):
self._ptr[idx] = _ffi_deref(item._cdata)


class BufferList:
def __init__(self, items: List["Buffer"], count: int = 0):
self._stashed = items
self._count = max(len(items), count)
self._ptr = _ffi_new("WGPUBuffer[]", self._count)
for idx, item in enumerate(items):
self._ptr[idx] = item._cdata


class SamplerList:
def __init__(self, items: List["Sampler"], count: int = 0):
self._stashed = items
Expand All @@ -8992,6 +8995,15 @@ def __init__(self, items: List["TextureView"], count: int = 0):
self._ptr[idx] = item._cdata


class PipelineStatisticNameList:
def __init__(self, items: List["PipelineStatisticName"], count: int = 0):
self._stashed = items
self._count = max(len(items), count)
self._ptr = _ffi_new("WGPUPipelineStatisticName[]", self._count)
for idx, item in enumerate(items):
self._ptr[idx] = int(item)


class IntList:
def __init__(self, items: List[int], count: int = 0):
self._stashed = items
Expand Down

0 comments on commit c5918c2

Please sign in to comment.