From a626d4fa3140018b8c1cc071cd50a36b49950da1 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 10 Aug 2020 14:27:31 +0800 Subject: [PATCH] optimize TFPark create tf model performance (#2694) * optimize create tf model * fix style * fix style --- .../orca/src/bigdl/orca/tfpark/tf_optimizer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/orca/src/bigdl/orca/tfpark/tf_optimizer.py b/python/orca/src/bigdl/orca/tfpark/tf_optimizer.py index 888e2b76b9a..ee606668afe 100644 --- a/python/orca/src/bigdl/orca/tfpark/tf_optimizer.py +++ b/python/orca/src/bigdl/orca/tfpark/tf_optimizer.py @@ -113,16 +113,16 @@ def __init__(self, training_helper_layer, criterion, val_methods): def _expand_inputs(inputs, tensors_with_value, loss): additional_inputs = [] additional_values = [] - all_required_inputs = find_placeholders([loss]) - all_required_inputs_names = [v.name for v in all_required_inputs] + inputs = nest.flatten(inputs) + names = set([i.name for i in inputs]) + if tensors_with_value: for t, v in tensors_with_value.items(): - if t.name in all_required_inputs_names: - additional_inputs.append(t) - additional_values.append(v) - - if not isinstance(inputs, list): - inputs = nest.flatten(inputs) + if t.name in names: + msg = f"tensor {t} already in inputs, cannot put it in tensor_with_value" + raise ValueError(msg) + additional_inputs.append(t) + additional_values.append(v) return inputs, additional_inputs, additional_values @@ -442,7 +442,6 @@ def _get_vars_grads(loss): def _get_vars_grads_from_train_op(train_op): def predicate(t): return t.name.split("/")[-1].startswith("zoo_identity_op_for_grad") - grads = find_tensors([train_op], predicate) grad_ops = [grad.op for grad in grads] variables = []