Skip to content

Commit

Permalink
fixed enable_binary_blob option for CWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed May 6, 2024
1 parent 2056881 commit 1711e97
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 41 deletions.
2 changes: 1 addition & 1 deletion dnn/torch/lossgen/export_lossgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def c_export(args, model):

message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
writer.header.write(
f"""
#include "opus_types.h"
Expand Down
77 changes: 38 additions & 39 deletions dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,50 +120,49 @@ def __init__(self,
def _finalize_header(self):

# create model type
if self.enable_binary_blob:
if self.add_typedef:
self.header.write(f"\ntypedef struct {{")
else:
self.header.write(f"\nstruct {self.model_struct_name} {{")
for name, data in self.layer_dict.items():
layer_type = data[0]
self.header.write(f"\n {layer_type} {name};")
if self.add_typedef:
self.header.write(f"\n}} {self.model_struct_name};\n")
else:
self.header.write(f"\n}};\n")

init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.header.write(f"\n{init_prototype};\n")
if self.add_typedef:
self.header.write(f"\ntypedef struct {{")
else:
self.header.write(f"\nstruct {self.model_struct_name} {{")
for name, data in self.layer_dict.items():
layer_type = data[0]
self.header.write(f"\n {layer_type} {name};")
if self.add_typedef:
self.header.write(f"\n}} {self.model_struct_name};\n")
else:
self.header.write(f"\n}};\n")

init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.header.write(f"\n{init_prototype};\n")

self.header.write(f"\n#endif /* {self.header_guard} */\n")

def _finalize_source(self):

if self.enable_binary_blob:
# create weight array
if len(set(self.weight_arrays)) != len(self.weight_arrays):
raise ValueError("error: detected duplicates in weight arrays")
self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
for name in self.weight_arrays:
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
self.source.write(f"#endif\n")
self.source.write(" {NULL, 0, 0, NULL}\n")
self.source.write("};\n")

self.source.write("#endif /* USE_WEIGHTS_FILE */\n")

# create init function definition
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
self.source.write(f"{init_prototype} {{\n")
for name, data in self.layer_dict.items():
self.source.write(f" if ({data[1]}) return 1;\n")
self.source.write(" return 0;\n")
self.source.write("}\n")
self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")

# create weight array
if len(set(self.weight_arrays)) != len(self.weight_arrays):
raise ValueError("error: detected duplicates in weight arrays")
if self.enable_binary_blob: self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
for name in self.weight_arrays:
self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
self.source.write(f"#endif\n")
self.source.write(" {NULL, 0, 0, NULL}\n")
self.source.write("};\n")

if self.enable_binary_blob: self.source.write("#endif /* USE_WEIGHTS_FILE */\n")

# create init function definition
init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
if self.enable_binary_blob: self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
self.source.write(f"{init_prototype} {{\n")
for name, data in self.layer_dict.items():
self.source.write(f" if ({data[1]}) return 1;\n")
self.source.write(" return 0;\n")
self.source.write("}\n")
if self.enable_binary_blob:self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")


def close(self):
Expand Down
2 changes: 1 addition & 1 deletion dnn/torch/weight-exchange/wexchange/c_export/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def print_vector(writer, vector, name, dtype='float', reshape_8x4=False, static=
#ifndef USE_WEIGHTS_FILE
'''
)
writer.weight_arrays.append(name)
writer.weight_arrays.append(name)

if reshape_8x4:
vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
Expand Down

0 comments on commit 1711e97

Please sign in to comment.