Skip to content

Commit

Permalink
first draft zig generator; only deserializes
Browse files Browse the repository at this point in the history
  • Loading branch information
sjml committed Oct 28, 2024
1 parent 93a0120 commit 65bea8a
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 12 deletions.
3 changes: 2 additions & 1 deletion beschi/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import argparse
import traceback

from .protocol import Protocol
from .writers import all_writers, experimental_writers
Expand Down Expand Up @@ -74,7 +75,7 @@ def main():
try:
output = writer.generate()
except NotImplementedError as nie:
sys.stderr.write(f"{nie}\n")
sys.stderr.write(f"{traceback.format_exc()}\n")
sys.exit(1)

if args.output == None:
Expand Down
76 changes: 76 additions & 0 deletions beschi/writers/boilerplate/Zig.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
fn _numberTypeIsValid(comptime T: type) bool {
const validNumericTypes = [_]type{
u8, i8,
u16, i16,
u32, i32,
u64, i64,
f32, f64,
};
for (validNumericTypes) |vt| {
if (T == vt) {
return true;
}
}
return false;
}

pub fn readNumber(comptime T: type, offset: usize, buffer: []u8) struct { value: T, bytes_read: usize } {
comptime {
if (!_numberTypeIsValid(T)) {
@compileError("Invalid number type");
}
}

switch (T) {
f32 => return .{ .value = @bitCast(std.mem.readInt(u32, buffer[offset..][0..@sizeOf(T)], .little)), .bytes_read = @sizeOf(T) },
f64 => return .{ .value = @bitCast(std.mem.readInt(u64, buffer[offset..][0..@sizeOf(T)], .little)), .bytes_read = @sizeOf(T) },
else => return .{ .value = std.mem.readInt(T, buffer[offset..][0..@sizeOf(T)], .little), .bytes_read = @sizeOf(T) },
}
}

pub fn readString(allocator: std.mem.Allocator, offset: usize, buffer: []u8) !struct { value: []u8, bytes_read: usize } {
const len_read = readNumber({# STRING_SIZE_TYPE #}, offset, buffer);
const len = len_read.value;
var str = try allocator.alloc(u8, len);
for (0..len) |i| {
str[i] = buffer[offset + len_read.bytes_read + i];
}
return .{ .value = str, .bytes_read = @sizeOf({# STRING_SIZE_TYPE #}) + len };
}

pub fn readList(comptime T: type, allocator: std.mem.Allocator, offset: usize, buffer: []u8) !struct { value: []T, bytes_read: usize } {
var local_offset = offset;
const len_read = readNumber({# LIST_SIZE_TYPE #}, local_offset, buffer);
const len = len_read.value;
local_offset += len_read.bytes_read;
var list = try allocator.alloc(T, len);

for (0..len) |i| {
if (comptime _numberTypeIsValid(T)) {
const list_read = readNumber(T, local_offset, buffer);
list[i] = list_read.value;
local_offset += list_read.bytes_read;
} else {
switch (T) {
[]u8 => {
const list_read = try readString(allocator, local_offset, buffer);
list[i] = list_read.value;
local_offset += list_read.bytes_read;
},
else => {
if (comptime _typeIsSimple(T)) {
const list_read = try T.fromBytes(local_offset, buffer);
list[i] = list_read.value;
local_offset += list_read.bytes_read;
}
else {
const list_read = try T.fromBytes(allocator, local_offset, buffer);
list[i] = list_read.value;
local_offset += list_read.bytes_read;
}
},
}
}
}
return .{ .value = list, .bytes_read = local_offset - offset };
}
8 changes: 5 additions & 3 deletions beschi/writers/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,11 @@ def gen_implementation(self, sname: str, sdata: Struct):

def generate(self) -> str:
self.output = []
self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line( "/*")
self.write_line(f" This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( " <https://github.com/sjml/beschi>")
self.write_line( " Do not edit directly.")
self.write_line( "*/")
self.write_line()

if self.embed_protocol:
Expand Down
2 changes: 1 addition & 1 deletion beschi/writers/csharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def generate(self) -> str:

self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line( "// Do not edit directly.")
self.write_line()

if self.embed_protocol:
Expand Down
2 changes: 1 addition & 1 deletion beschi/writers/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def generate(self) -> str:

self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line( "// Do not edit directly.")
self.write_line()

if self.embed_protocol:
Expand Down
2 changes: 1 addition & 1 deletion beschi/writers/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def generate(self) -> str:
self.output = []
self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line( "// Do not edit directly.")
self.write_line()

if self.embed_protocol:
Expand Down
2 changes: 1 addition & 1 deletion beschi/writers/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def generate(self) -> str:

self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line( "// Do not edit directly.")
self.write_line()
self.write_line("import Foundation")
self.write_line()
Expand Down
2 changes: 1 addition & 1 deletion beschi/writers/typescript.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def generate(self) -> str:

self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line( "// Do not edit directly.")
self.write_line()

if self.embed_protocol:
Expand Down
212 changes: 212 additions & 0 deletions beschi/writers/zig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import argparse

from ..protocol import Protocol, Struct, Variable, NUMERIC_TYPE_SIZES
from ..writer import Writer, TextUtil
from .. import LIB_NAME, LIB_VERSION

LANGUAGE_NAME = "Zig"


class ZigWriter(Writer):
language_name = LANGUAGE_NAME
default_extension = ".zig"
in_progress = True

def __init__(self, p: Protocol, extra_args: dict[str,any] = {}):
super().__init__(protocol=p)

self.embed_protocol = extra_args["embed_protocol"]

self.type_mapping["byte"] = "u8"
self.type_mapping["bool"] = "bool"
self.type_mapping["uint16"] = "u16"
self.type_mapping["int16"] = "i16"
self.type_mapping["uint32"] = "u32"
self.type_mapping["int32"] = "i32"
self.type_mapping["uint64"] = "u64"
self.type_mapping["int64"] = "i64"
self.type_mapping["float"] = "f32"
self.type_mapping["double"] = "f64"
self.type_mapping["string"] = "[]u8"

self.base_defaults: dict[str,str] = {
"byte": "0",
"bool": "false",
"uint16": "0",
"int16": "0",
"uint32": "0",
"int32": "0",
"uint64": "0",
"int64": "0",
"float": "0.0",
"double": "0.0",
"string": '""',
}

def deserializer(self, var: Variable, accessor: str, parent_is_simple: bool, simple_offset: int):
if parent_is_simple: # also means that *var* is simple because recursion!
if var.vartype == "bool":
self.write_line(f"const {accessor}_{var.name} = readNumber(u8, offset + {simple_offset}, buffer).value != 0;")
elif var.vartype in NUMERIC_TYPE_SIZES.keys():
self.write_line(f"const {accessor}_{var.name} = readNumber({self.type_mapping[var.vartype]}, offset + {simple_offset}, buffer).value;")
else:
self.write_line(f"const {accessor}_{var.name}_read = {var.vartype}.fromBytes({simple_offset}, buffer);")
self.write_line(f"const {accessor}_{var.name} = {accessor}_{var.name}_read.value;")
else:
if var.is_list:
self.write_line(f"const {accessor}_{var.name}_read = try readList({self.type_mapping[var.vartype]}, allocator, local_offset, buffer);")
self.write_line(f"const {accessor}_{var.name} = {accessor}_{var.name}_read.value;")
self.write_line(f"local_offset += {accessor}_{var.name}_read.bytes_read;")
elif var.vartype == "bool":
self.write_line(f"const {accessor}_{var.name}_read = readNumber(u8, local_offset, buffer);")
self.write_line(f"const {accessor}_{var.name} = {accessor}_{var.name}_read.value != 0;")
self.write_line(f"local_offset += {accessor}_{var.name}_read.bytes_read;")
elif var.vartype in NUMERIC_TYPE_SIZES.keys():
self.write_line(f"const {accessor}_{var.name}_read = readNumber({self.type_mapping[var.vartype]}, local_offset, buffer);")
self.write_line(f"const {accessor}_{var.name} = {accessor}_{var.name}_read.value;")
self.write_line(f"local_offset += {accessor}_{var.name}_read.bytes_read;")
elif var.vartype == "string":
self.write_line(f"const {accessor}_{var.name}_read = try readString(allocator, local_offset, buffer);")
self.write_line(f"const {accessor}_{var.name} = {accessor}_{var.name}_read.value;")
self.write_line(f"local_offset += {accessor}_{var.name}_read.bytes_read;")
else:
self.write_line(f"const {accessor}_{var.name}_read = try {var.vartype}.fromBytes({'' if var.is_simple() else 'allocator, '}local_offset, buffer);")
self.write_line(f"const {accessor}_{var.name} = {accessor}_{var.name}_read.value;")
self.write_line(f"local_offset += {accessor}_{var.name}_read.bytes_read;")

self.write_line()

def destructor(self, var: Variable, accessor: str):
if var.is_simple():
return
elif var.is_list:
if not var.is_simple(True):
idx = self.indent_level
self.write_line(f"for ({accessor}{var.name}) |{'*' if not var.vartype == 'string' else ''}item{idx}| {{")
self.indent_level += 1
inner = Variable(self.protocol, f"item{idx}", var.vartype)
self.destructor(inner, "")
self.indent_level -= 1
self.write_line("}")
self.write_line(f"allocator.free({accessor}{var.name});")
elif var.vartype == "string":
self.write_line(f"allocator.free({accessor}{var.name});")
else:
self.write_line(f"{accessor}{var.name}.deinit(allocator);")


def gen_struct(self, sname: str, sdata: Struct):
self.write_line(f"pub const {sname} = struct {{")
self.indent_level += 1
for var in sdata.members:
if var.is_list:
self.write_line(f"{var.name}: []{self.type_mapping[var.vartype]},")
else:
default_value = self.base_defaults.get(var.vartype)
if default_value == None:
if var.is_simple():
default_value = f"{var.vartype}{{}}"
else:
default_value = None
if default_value != None:
self.write_line(f"{var.name}: {self.type_mapping[var.vartype]} = {default_value},")
else:
self.write_line(f"{var.name}: {self.type_mapping[var.vartype]},")
self.write_line()

self.write_line(f"pub fn fromBytes({'' if sdata.is_simple() else 'allocator: std.mem.Allocator, '}offset: usize, buffer: []u8) !struct {{ value: {sname}, bytes_read: usize }} {{")
self.indent_level += 1
simple_offset = -1
if sdata.is_simple():
simple_offset = 0
else:
self.write_line("var local_offset = offset;")
self.write_line()
for mem in sdata.members:
self.deserializer(mem, sname, sdata.is_simple(), simple_offset)
if sdata.is_simple():
simple_offset += self.protocol.get_size_of(mem.vartype)
self.write_line(f"return .{{ .value = {sname}{{")
self.indent_level += 1
for var in sdata.members:
self.write_line(f".{var.name} = {sname}_{var.name},")
self.indent_level -= 1
if sdata.is_simple():
self.write_line(f"}}, .bytes_read = {self.protocol.get_size_of(sdata.name)} }};")
else:
self.write_line(f"}}, .bytes_read = local_offset - offset }};")
self.indent_level -= 1
self.write_line("}")

if not sdata.is_simple():
self.write_line()
self.write_line(f"pub fn deinit(self: *{sname}, allocator: std.mem.Allocator) void {{")
self.indent_level += 1
[self.destructor(mem, "self.") for mem in sdata.members]
self.indent_level -= 1
self.write_line("}")


self.indent_level -= 1
self.write_line("};")
self.write_line()

def generate(self) -> str:
self.output = []

self.write_line(f"// This file was automatically generated by {LIB_NAME} v{LIB_VERSION}")
self.write_line( "// <https://github.com/sjml/beschi>")
self.write_line(f"// Do not edit directly.")
self.write_line()

if self.embed_protocol:
self.write_line("// DATA PROTOCOL")
self.write_line("// -----------------")
[self.write_line(f"// {l}") for l in self.protocol.protocol_string.splitlines()]
self.write_line("// -----------------")
self.write_line("// END DATA PROTOCOL")
self.write_line()
self.write_line()

self.write_line("const std = @import(\"std\");")
self.write_line()

self.write_line( "fn _typeIsSimple(comptime T: type) bool {")
self.write_line( " if (comptime _numberTypeIsValid(T)) {")
self.write_line( " return true;")
self.write_line( " }")
self.write_line( " const simpleTypes = [_]type{")
simple_structs = [sname for sname, sdata in self.protocol.structs.items() if sdata.is_simple()]
simple_messages = [mname for mname, mdata in self.protocol.messages.items() if mdata.is_simple()]
if len(simple_structs):
self.write_line(f" {', '.join(simple_structs )},")
if len(simple_messages):
self.write_line(f" {', '.join(simple_messages)},")
self.write_line( " };")
self.write_line( " for (simpleTypes) |vt| {")
self.write_line( " if (T == vt) {")
self.write_line( " return true;")
self.write_line( " }")
self.write_line( " }")
self.write_line( " return false;")
self.write_line( "}")
self.write_line()

subs = [
("{# STRING_SIZE_TYPE #}", self.get_native_string_size()),
("{# LIST_SIZE_TYPE #}" , self.get_native_list_size()),
]
self.add_boilerplate(subs)
self.write_line()

for sname, sdata in self.protocol.structs.items():
self.gen_struct(sname, sdata)

for mname, mdata in self.protocol.messages.items():
self.gen_struct(mname, mdata)


self.write_line()
assert self.indent_level == 0

return "\n".join(self.output)
Loading

0 comments on commit 65bea8a

Please sign in to comment.