Skip to content

Commit

Permalink
Merge pull request #128 from mdanilow/feature/strings_attr
Browse files Browse the repository at this point in the history
Feature/strings attr
  • Loading branch information
maltanar authored Aug 12, 2024
2 parents 326a525 + 654bf15 commit c4c16f7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/qonnx/custom_op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def get_nodeattr(self, name):
if dtype == "s":
# decode string attributes
ret = ret.decode("utf-8")
elif dtype == "strings":
ret = [x.decode("utf-8") for x in ret]
elif dtype == "t":
# use numpy helper to convert TensorProto -> np array
ret = np_helper.to_array(ret)
Expand Down Expand Up @@ -123,13 +125,15 @@ def set_nodeattr(self, name, value):
# encode string attributes
value = value.encode("utf-8")
attr.__setattr__(dtype, value)
elif dtype == "strings":
attr.strings[:] = [x.encode("utf-8") for x in value]
elif dtype == "floats": # list of floats
attr.floats[:] = value
elif dtype == "ints": # list of integers
attr.ints[:] = value
elif dtype == "t": # single tensor
attr.t.CopyFrom(value)
elif dtype in ["strings", "tensors", "graphs", "sparse_tensors"]:
elif dtype in ["tensors", "graphs", "sparse_tensors"]:
# untested / unsupported attribute types
# add testcases & appropriate getters before enabling
raise Exception("Attribute type %s not yet supported" % dtype)
Expand Down
12 changes: 11 additions & 1 deletion tests/custom_op/test_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@

class AttrTestOp(CustomOp):
def get_nodeattr_types(self):
return {"tensor_attr": ("t", True, np.asarray([]))}
my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])}
return my_attrs

def make_shape_compatible_op(self, model):
param_tensor = self.get_nodeattr("tensor_attr")
Expand Down Expand Up @@ -70,6 +71,7 @@ def test_attr():
strarr = np.array2string(w, separator=", ")
w_str = strarr.replace("[", "{").replace("]", "}").replace(" ", "")
tensor_attr_str = f"int8{wshp_str} {w_str}"
strings_attr = ["a", "bc", "def"]

input = f"""
<
Expand All @@ -86,9 +88,17 @@ def test_attr():
model = oprs.parse_model(input)
model = ModelWrapper(model)
inst = getCustomOp(model.graph.node[0])

w_prod = inst.get_nodeattr("tensor_attr")
assert (w_prod == w).all()
w = w - 1
inst.set_nodeattr("tensor_attr", w)
w_prod = inst.get_nodeattr("tensor_attr")
assert (w_prod == w).all()

inst.set_nodeattr("strings_attr", strings_attr)
strings_attr_prod = inst.get_nodeattr("strings_attr")
assert strings_attr_prod == strings_attr
strings_attr_prod[0] = "test"
inst.set_nodeattr("strings_attr", strings_attr_prod)
assert inst.get_nodeattr("strings_attr") == ["test"] + strings_attr[1:]

0 comments on commit c4c16f7

Please sign in to comment.