Skip to content

Commit

Permalink
Merge pull request tensorflow#13241 from caisq/branch_169712715
Browse files Browse the repository at this point in the history
Branch 169712715
  • Loading branch information
caisq authored Sep 22, 2017
2 parents b72f1e1 + abd1c64 commit ea94bbe
Show file tree
Hide file tree
Showing 22 changed files with 271 additions and 86 deletions.
6 changes: 3 additions & 3 deletions tensorflow/cc/gradients/nn_grad_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ class NNGradTest : public ::testing::Test {
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 2e-4);
EXPECT_LT(max_error, 2.2e-4);
}

void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape) {
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, x, x_init_value, y, y_shape, &max_error)));
EXPECT_LT(max_error, 2e-4);
EXPECT_LT(max_error, 2.2e-4);
}

void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
Expand All @@ -53,7 +53,7 @@ class NNGradTest : public ::testing::Test {
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 2e-4);
EXPECT_LT(max_error, 2.2e-4);
}

Scope scope_;
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/compiler/jit/xla_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,10 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice* xla_device = dynamic_cast<XlaDevice*>(ctx->device());
if (xla_device == nullptr) {
return errors::Internal(
"GetMetadata should be called on an XLA device. This usually means an "
"internal bug or Op is placed on the wrong device.");
"Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
"\". GetMetadata must only be called on an XLA device. Either an "
"internal bug has been triggered, or an XLA-specific op has been "
"placed on the wrong device.");
}
*metadata = &(xla_device->xla_metadata_);
return Status::OK();
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/compiler/tf2xla/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:concat_lib",
"//tensorflow/core/kernels:conv_ops",
Expand All @@ -112,7 +112,6 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
],
)

Expand All @@ -137,9 +136,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:bounds_check",
],
)
Expand Down
13 changes: 7 additions & 6 deletions tensorflow/compiler/xla/service/hlo_graph_dumper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,8 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);

string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile) {
const HloExecutionProfile* hlo_execution_profile,
bool show_metadata) {
string graph;
string graph_url;
if (debug_options.xla_hlo_dump_as_graphdef()) {
Expand All @@ -1237,11 +1238,11 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_url = FileGraphRenderer().RenderGraph(
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
} else {
graph = HloDotDumper(
&computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
/*show_metadata=*/false, hlo_execution_profile, NodeFilter())
.Dump();
graph =
HloDotDumper(&computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
show_metadata, hlo_execution_profile, NodeFilter())
.Dump();
graph_url = GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/hlo_graph_dumper.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ string MaybeDumpHloModule(const HloModule& module, const string& label,
// registry is used.
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile = nullptr);
const HloExecutionProfile* hlo_execution_profile = nullptr,
bool show_metadata = false);

// Like DumpGraph, but renders only nodes "near" the given node in the graph.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import print_function


import re
import re as _re

from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
Expand Down Expand Up @@ -644,7 +644,7 @@ def _is_removed_mentioned(s, removed_op_names):
# /foo/bar. This regex ensures that we handle these two nodes
# as separate entities. It matches on nodes having names in the form of
# '/foo/bar_x' as well as nodes having names in the form of 'foo.'
s_names = re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', compat.as_str_any(s))
s_names = _re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', compat.as_str_any(s))
for removed_op_name in removed_op_names:
for s_name in s_names:
if s_name.endswith(removed_op_name):
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/pending_counts.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ limitations under the License.

namespace tensorflow {

// PendingCounts is internal helper class to keep track of pending and
// PendingCounts is an internal helper class to keep track of pending and
// dead counts for nodes, for use in the ExecutorState module. It
// holds a map from Handles to various counts for that handle. This
// information is needed per frame iteration. The amount of memory
Expand All @@ -39,7 +39,7 @@ namespace tensorflow {
// }
//
// When we actually want to start an iteration we first create a
// nPendingCounts object and then index into it using the precomputed
// PendingCounts object and then index into it using the precomputed
// handles:

// PendingCounts counts(layout);
Expand Down
74 changes: 49 additions & 25 deletions tensorflow/core/grappler/optimizers/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,16 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs) {
TensorVector inputs;
auto inputs_cleanup = gtl::MakeCleanup([&inputs] {
TensorVector output_tensors;
auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
for (const auto& input : inputs) {
delete input.tensor;
}
for (const auto& output : output_tensors) {
if (output.tensor) {
delete output.tensor;
}
}
});

for (const auto& input : node.input()) {
Expand All @@ -439,7 +445,6 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
inputs.emplace_back(value);
}

TensorVector output_tensors;
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
if (output_tensors.empty()) {
return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
Expand All @@ -452,7 +457,6 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
}
if (output_tensors[i].tensor) {
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
delete output_tensors[i].tensor;
} else {
// Create an empty NodeDef to identify dead outputs (e.g. the output of a
// switch that's not selected by the switch predicate).
Expand All @@ -462,7 +466,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
return Status::OK();
}

Status ConstantFolding::FoldNode(NodeDef* node) {
Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
if (IsMerge(*node)) {
// Merge nodes are special, in the sense that they execute as soon as one of
// their input is ready. We can therefore fold a merge node iff it has at
Expand Down Expand Up @@ -511,15 +515,15 @@ Status ConstantFolding::FoldNode(NodeDef* node) {
"already present in the graph"));
}

NodeDef* const_out = added_graph_.add_node();
NodeDef* const_out = output_graph->add_node();
*const_out = *input_node;
const_out->set_name(const_out_name);
const_out->set_device(node->device());
*const_out->add_input() = AsControlDependency(*node);
node_map_->AddNode(const_out->name(), const_out);
node_map_->AddOutput(node->name(), const_out->name());

NodeDef* const_index = added_graph_.add_node();
NodeDef* const_index = output_graph->add_node();
const_index->set_op("Const");
Tensor index(DT_INT32, TensorShape({}));
index.flat<int32>()(0) = input_index;
Expand Down Expand Up @@ -608,7 +612,7 @@ Status ConstantFolding::FoldNode(NodeDef* node) {
return errors::AlreadyExists(strings::StrCat(
const_node->name(), "already present in the graph"));
}
NodeDef* added_node = added_graph_.add_node();
NodeDef* added_node = output_graph->add_node();
*added_node = *const_node;
added_node->set_device(node->device());
node_map_->AddNode(added_node->name(), added_node);
Expand Down Expand Up @@ -679,7 +683,7 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
}
// We need to record a copy of output nodes before FoldNode() modifies it.
std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
Status s = FoldNode(node);
Status s = FoldNode(node, output);
processed_nodes.insert(node->name());
if (!s.ok()) {
VLOG(1) << "Failed to fold node " << node->name() << ": " << s;
Expand All @@ -692,14 +696,19 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
}
}

// Build the graph after constant folding.
for (const auto& node : added_graph_.node()) {
// Delete the newly created nodes that don't feed anything.
int last = output->node_size() - 1;
for (int i = output->node_size() - 1; i >= 0; --i) {
const NodeDef& node = output->node(i);
auto outputs = node_map_->GetOutputs(node.name());
if (!outputs.empty()) {
auto added_node = output->add_node();
*added_node = node;
if (outputs.empty()) {
output->mutable_node()->SwapElements(i, last);
last--;
}
}
output->mutable_node()->DeleteSubrange(last + 1,
output->node_size() - last - 1);

for (const auto& node : graph_.node()) {
// If no fetch nodes is provided, we conservatively
// keep all nodes in the original graph in case users need to fetch
Expand Down Expand Up @@ -843,11 +852,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
return Status::OK();
}

Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
graph_ = item.graph;
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output) {
node_map_.reset(new NodeMap(&graph_));
nodes_to_preserve_ = item.NodesToPreserve();
nodes_whitelist_.clear();
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
// has a single fanout, it would be rewritten as a constant with the same
// node name, and therefore users are still able to fetch it. This is not
Expand All @@ -861,16 +870,9 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
nodes_whitelist_.insert(fetch_node->name());
}
}
*output = GraphDef();
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
cpu_device_ = owned_device_.get();
}

bool has_feed = !item.feed.empty();
has_fetch_ = !item.fetch.empty();

GraphProperties properties(item);
bool has_feed = !item.feed.empty();
if (!has_feed) {
// Only use static shape information when there is no feed in the
// graph. That's because it's possible to feed a placeholder with a tensor
Expand All @@ -889,6 +891,28 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!has_feed) {
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
}
return Status::OK();
}

Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
nodes_to_preserve_ = item.NodesToPreserve();

if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
cpu_device_ = owned_device_.get();
}

has_fetch_ = !item.fetch.empty();

GrapplerItem item_to_optimize = item;
*output = item.graph;
do {
graph_.Swap(output);
item_to_optimize.graph = graph_;
*output = GraphDef();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
} while (output->node_size() < graph_.node_size());

*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/grappler/optimizers/constant_folding.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ConstantFolding : public GraphOptimizer {
Status EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs);

Status FoldNode(NodeDef* node);
Status FoldNode(NodeDef* node, GraphDef* output_graph);

Status FoldGraph(GraphDef* output);

Expand All @@ -69,6 +69,9 @@ class ConstantFolding : public GraphOptimizer {
const GraphProperties& properties) const;
Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);

Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);

// Points to an externally provided device or to owned_device_;
DeviceBase* cpu_device_;
std::unique_ptr<DeviceBase> owned_device_;
Expand All @@ -78,7 +81,6 @@ class ConstantFolding : public GraphOptimizer {
std::unique_ptr<NodeMap> node_map_;
std::unordered_set<string> nodes_to_preserve_;
std::unordered_set<string> nodes_whitelist_;
GraphDef added_graph_;
bool has_fetch_;
};

Expand Down
13 changes: 5 additions & 8 deletions tensorflow/java/src/gen/gen_ops.bzl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# -*- Python -*-

load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_copts")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static")
load("//tensorflow:tensorflow.bzl",
"tf_binary_additional_srcs",
"tf_cc_binary",
"tf_copts")

# Given a list of "ops_libs" (a list of files in the core/ops directory
# without their .cc extensions), generate Java wrapper code for all operations
Expand Down Expand Up @@ -54,10 +54,7 @@ def tf_java_op_gen_srcjar(name,
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
gen_tools += if_static(
extra_deps=[],
otherwise=["//tensorflow:libtensorflow_framework.so"]
)
gen_tools += tf_binary_additional_srcs()
native.genrule(
name=name,
outs=[gen_srcjar],
Expand Down
Loading

0 comments on commit ea94bbe

Please sign in to comment.