Skip to content

Commit

Permalink
Get matmul_ijk working
Browse files Browse the repository at this point in the history
  • Loading branch information
weiya711 committed Oct 11, 2023
1 parent b5d9547 commit 15833ac
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions scripts/gen_sam_apps/test_generating_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ def get_common_test_name(test_name):
return test_name


def get_out_crd_str(d, u_, index_value):
# By default, the input primitive connected to a crddrop will be a level scanner
out_crd_str = "out_crd"
# However, if the input primitive is another crddrop, we need to make sure it's reading from
# the correct input crddrop output.
if d[u_]["type"] == "crddrop":
if index_value == d[u_]["inner"]:
out_crd_str += "_inner"
elif index_value == d[u_]["outer"]:
out_crd_str += "_outer"
return out_crd_str


def generate_datasets_code(f, tensor_formats, scope_lvl, tensor_info, tensor_format_parse, test_name):
# Assuming the format is csr and csc:
for ten in tensor_format_parse.return_all_tensors():
Expand Down Expand Up @@ -539,7 +552,7 @@ def get_all_files(directory_path):
continue
out_name.append(filename[0:-3])
# checking if it is a file
print(out_name[-1])
print("Test Name:", out_name[-1])
if os.path.isfile(f):
file_paths.append(f)
return file_paths, out_name
Expand Down Expand Up @@ -810,9 +823,13 @@ def get_all_files(directory_path):
for u_ in data.get_parents()[v]:
index_value = data.get_edge_data()[v][data.get_parents()[v].index(u_)][-1]
if index_value == d[v]["inner"]:
f.write(tab(2) + d[v]["object"] + ".set_inner_crd" + "(" + d[u_]["object"] + ".out_crd())\n")
out_crd_str = get_out_crd_str(d, u_, index_value)
f.write(tab(2) + d[v]["object"] + ".set_inner_crd" + "(" + d[u_]["object"] + "." +
out_crd_str + "())\n")
if index_value == d[v]["outer"]:
f.write(tab(2) + d[v]["object"] + ".set_outer_crd" + "(" + d[u_]["object"] + ".out_crd())\n")
out_crd_str = get_out_crd_str(d, u_, index_value)
f.write(tab(2) + d[v]["object"] + ".set_outer_crd" + "(" + d[u_]["object"] + "." +
out_crd_str + "())\n")
nodes_updating_list.append(tab(2) + d[v]["object"] + ".update()\n")
# f.write(tab(2) + d[v]["object"] + ".update()\n\n")
data.add_done(v)
Expand Down Expand Up @@ -933,7 +950,6 @@ def get_all_files(directory_path):
if "val" not in data.get_edge_data()[v][i] and "spaccumulator" \
in d[u_]["object"]:
local_index = data.get_edge_data()[v][i][-1]
print(d[u_], " ", local_index, " ", apath)
if d[u_]["in0"] == local_index:
local_cord = "_inner"
else:
Expand Down

0 comments on commit 15833ac

Please sign in to comment.