Skip to content

Commit

Permalink
Tensorflow: properly register callbacks as SDFG symbols, remove argum…
Browse files Browse the repository at this point in the history
…ent dumping
  • Loading branch information
tbennun committed Oct 20, 2019
1 parent 4907c2c commit a172d11
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions dace/frontend/tensorflow/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,6 @@ def compile(self,
strict)
self.graph.draw_to_file()
compiled_sdfg = self.graph.compile(optimizer=False)

sdfg_args_filename = os.path.join(".dacecache", name,
"sdfg_args.pickle")
with open(sdfg_args_filename, "wb") as handle:
pickle.dump(sdfg_args, handle, pickle.HIGHEST_PROTOCOL)
sdfg_args.update(self.callbackFunctionDict)

############################
Expand Down Expand Up @@ -738,6 +733,10 @@ def tensorflow_callback(tf_op, *inputList, num_outputs=0):
*callback_input_types))
self.callbackFunctionDict[node_name] = tensorflow_callback

# Register callback in SDFG
self.graph.add_symbol(node_name,
self.callbackTypeDict[node_name].dtype)

callback_tasklet = self.state.add_tasklet(
node_name,
{*taskletInputs},
Expand Down

0 comments on commit a172d11

Please sign in to comment.