Skip to content

Commit

Permalink
optimize TFPark create tf model performance (intel-analytics#2694)
Browse files Browse the repository at this point in the history
* optimize create tf model

* fix style

* fix style
  • Loading branch information
yangw1234 authored Aug 10, 2020
1 parent cd2978c commit a626d4f
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions python/orca/src/bigdl/orca/tfpark/tf_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,16 @@ def __init__(self, training_helper_layer, criterion, val_methods):
def _expand_inputs(inputs, tensors_with_value, loss):
additional_inputs = []
additional_values = []
all_required_inputs = find_placeholders([loss])
all_required_inputs_names = [v.name for v in all_required_inputs]
inputs = nest.flatten(inputs)
names = set([i.name for i in inputs])

if tensors_with_value:
for t, v in tensors_with_value.items():
if t.name in all_required_inputs_names:
additional_inputs.append(t)
additional_values.append(v)

if not isinstance(inputs, list):
inputs = nest.flatten(inputs)
if t.name in names:
msg = f"tensor {t} already in inputs, cannot put it in tensor_with_value"
raise ValueError(msg)
additional_inputs.append(t)
additional_values.append(v)

return inputs, additional_inputs, additional_values

Expand Down Expand Up @@ -442,7 +442,6 @@ def _get_vars_grads(loss):
def _get_vars_grads_from_train_op(train_op):
def predicate(t):
return t.name.split("/")[-1].startswith("zoo_identity_op_for_grad")

grads = find_tensors([train_op], predicate)
grad_ops = [grad.op for grad in grads]
variables = []
Expand Down

0 comments on commit a626d4f

Please sign in to comment.