Skip to content

Commit

Permalink
Allow bitwise OR-ing of enum flags
Browse files Browse the repository at this point in the history
  • Loading branch information
PyryM committed Jan 12, 2024
1 parent 1d740a8 commit 6df337a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
8 changes: 6 additions & 2 deletions codegen/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ function pyOptional(pyType: string): string {
return IS_PY12 ? `${pyType} | None` : `Optional[${pyType}]`;
}

function pyUnion(a: string, b: string): string {
return IS_PY12 ? `${a} | ${b}` : `Union[${a}, ${b}]`;
function pyUnion(...args: string[]): string {
return IS_PY12 ? args.join(" | ") : `Union[${args.join(", ")}]`;
}

function onlyDefined<T>(items: (T | undefined)[]): T[] {
Expand All @@ -98,6 +98,7 @@ class CEnum implements CType {
kind: "enum" = "enum";
sanitized: { name: string; val: string }[] = [];
values: CEnumVal[] = [];
flagType?: string

constructor(public cName: string, public pyName: string, values: CEnumVal[]) {
this.mergeValues(values);
Expand Down Expand Up @@ -189,6 +190,9 @@ class ${this.pyName}:
else:
self.value = sum(set(flags))
def __or__(self, rhs: ${pyUnion(quoted(this.pyName), etypename)}) -> ${quoted(this.pyName)}:
return ${this.pyName}(int(self) | int(rhs))
def __int__(self) -> int:
return self.value
` + props.join("\n")
Expand Down
27 changes: 27 additions & 0 deletions webgoo/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,9 @@ def __init__(self, flags: Union[list["BufferUsage"], int]):
else:
self.value = sum(set(flags))

def __or__(self, rhs: Union["BufferUsageFlags", "BufferUsage"]) -> "BufferUsageFlags":
return BufferUsageFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down Expand Up @@ -826,6 +829,11 @@ def __init__(self, flags: Union[list["ColorWriteMask"], int]):
else:
self.value = sum(set(flags))

def __or__(
self, rhs: Union["ColorWriteMaskFlags", "ColorWriteMask"]
) -> "ColorWriteMaskFlags":
return ColorWriteMaskFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down Expand Up @@ -892,6 +900,9 @@ def __init__(self, flags: Union[list["MapMode"], int]):
else:
self.value = sum(set(flags))

def __or__(self, rhs: Union["MapModeFlags", "MapMode"]) -> "MapModeFlags":
return MapModeFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down Expand Up @@ -925,6 +936,9 @@ def __init__(self, flags: Union[list["ShaderStage"], int]):
else:
self.value = sum(set(flags))

def __or__(self, rhs: Union["ShaderStageFlags", "ShaderStage"]) -> "ShaderStageFlags":
return ShaderStageFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down Expand Up @@ -969,6 +983,11 @@ def __init__(self, flags: Union[list["TextureUsage"], int]):
else:
self.value = sum(set(flags))

def __or__(
self, rhs: Union["TextureUsageFlags", "TextureUsage"]
) -> "TextureUsageFlags":
return TextureUsageFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down Expand Up @@ -1035,6 +1054,11 @@ def __init__(self, flags: Union[list["InstanceBackend"], int]):
else:
self.value = sum(set(flags))

def __or__(
self, rhs: Union["InstanceBackendFlags", "InstanceBackend"]
) -> "InstanceBackendFlags":
return InstanceBackendFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down Expand Up @@ -1112,6 +1136,9 @@ def __init__(self, flags: Union[list["InstanceFlag"], int]):
else:
self.value = sum(set(flags))

def __or__(self, rhs: Union["InstanceFlags", "InstanceFlag"]) -> "InstanceFlags":
return InstanceFlags(int(self) | int(rhs))

def __int__(self) -> int:
return self.value

Expand Down
4 changes: 3 additions & 1 deletion webgoo/conveniences.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def read_rgba_texture(device: wg.Device, tex: wg.Texture):

def create_buffer_with_data(device: wg.Device, data: bytes) -> wg.Buffer:
bsize = len(data)
buffer = device.createBuffer(usage = wg.BufferUsage.CopySrc, size=bsize, mappedAtCreation=True)
buffer = device.createBuffer(
usage=wg.BufferUsage.CopySrc, size=bsize, mappedAtCreation=True
)
range = buffer.getMappedRange(0, bsize)
ffi.memmove(range._ptr, data, bsize)
buffer.unmap()
Expand Down

0 comments on commit 6df337a

Please sign in to comment.