Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 169712715 #13241

Merged
merged 15 commits into from
Sep 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorflow/cc/gradients/nn_grad_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ class NNGradTest : public ::testing::Test {
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)));
EXPECT_LT(max_error, 2e-4);
EXPECT_LT(max_error, 2.2e-4);
}

void RunTest(const Output& x, const Tensor& x_init_value, const Output& y,
const TensorShape& y_shape) {
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, x, x_init_value, y, y_shape, &max_error)));
EXPECT_LT(max_error, 2e-4);
EXPECT_LT(max_error, 2.2e-4);
}

void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
Expand All @@ -53,7 +53,7 @@ class NNGradTest : public ::testing::Test {
float max_error;
TF_ASSERT_OK((ComputeGradientError<float, float, float>(
scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 2e-4);
EXPECT_LT(max_error, 2.2e-4);
}

Scope scope_;
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/compiler/jit/xla_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,10 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice* xla_device = dynamic_cast<XlaDevice*>(ctx->device());
if (xla_device == nullptr) {
return errors::Internal(
"GetMetadata should be called on an XLA device. This usually means an "
"internal bug or Op is placed on the wrong device.");
"Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
"\". GetMetadata must only be called on an XLA device. Either an "
"internal bug has been triggered, or an XLA-specific op has been "
"placed on the wrong device.");
}
*metadata = &(xla_device->xla_metadata_);
return Status::OK();
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/compiler/tf2xla/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:concat_lib",
"//tensorflow/core/kernels:conv_ops",
Expand All @@ -112,7 +112,6 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow_opensource",
],
)

Expand All @@ -137,9 +136,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:bounds_check",
],
)
Expand Down
13 changes: 7 additions & 6 deletions tensorflow/compiler/xla/service/hlo_graph_dumper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,8 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);

string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile) {
const HloExecutionProfile* hlo_execution_profile,
bool show_metadata) {
string graph;
string graph_url;
if (debug_options.xla_hlo_dump_as_graphdef()) {
Expand All @@ -1237,11 +1238,11 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_url = FileGraphRenderer().RenderGraph(
graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
} else {
graph = HloDotDumper(
&computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
/*show_metadata=*/false, hlo_execution_profile, NodeFilter())
.Dump();
graph =
HloDotDumper(&computation, label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
show_metadata, hlo_execution_profile, NodeFilter())
.Dump();
graph_url = GetGraphRenderer()->RenderGraph(
graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/hlo_graph_dumper.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ string MaybeDumpHloModule(const HloModule& module, const string& label,
// registry is used.
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile = nullptr);
const HloExecutionProfile* hlo_execution_profile = nullptr,
bool show_metadata = false);

// Like DumpGraph, but renders only nodes "near" the given node in the graph.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import print_function


import re
import re as _re

from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
Expand Down Expand Up @@ -644,7 +644,7 @@ def _is_removed_mentioned(s, removed_op_names):
# /foo/bar. This regex ensures that we handle these two nodes
# as separate entities. It matches on nodes having names in the form of
# '/foo/bar_x' as well as nodes having names in the form of 'foo.'
s_names = re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', compat.as_str_any(s))
s_names = _re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', compat.as_str_any(s))
for removed_op_name in removed_op_names:
for s_name in s_names:
if s_name.endswith(removed_op_name):
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/pending_counts.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ limitations under the License.

namespace tensorflow {

// PendingCounts is internal helper class to keep track of pending and
// PendingCounts is an internal helper class to keep track of pending and
// dead counts for nodes, for use in the ExecutorState module. It
// holds a map from Handles to various counts for that handle. This
// information is needed per frame iteration. The amount of memory
Expand All @@ -39,7 +39,7 @@ namespace tensorflow {
// }
//
// When we actually want to start an iteration we first create a
// nPendingCounts object and then index into it using the precomputed
// PendingCounts object and then index into it using the precomputed
// handles:

// PendingCounts counts(layout);
Expand Down
74 changes: 49 additions & 25 deletions tensorflow/core/grappler/optimizers/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,16 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs) {
TensorVector inputs;
auto inputs_cleanup = gtl::MakeCleanup([&inputs] {
TensorVector output_tensors;
auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
for (const auto& input : inputs) {
delete input.tensor;
}
for (const auto& output : output_tensors) {
if (output.tensor) {
delete output.tensor;
}
}
});

for (const auto& input : node.input()) {
Expand All @@ -439,7 +445,6 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
inputs.emplace_back(value);
}

TensorVector output_tensors;
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
if (output_tensors.empty()) {
return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
Expand All @@ -452,7 +457,6 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
}
if (output_tensors[i].tensor) {
outputs->push_back(CreateNodeDef(node_name, output_tensors[i]));
delete output_tensors[i].tensor;
} else {
// Create an empty NodeDef to identify dead outputs (e.g. the output of a
// switch that's not selected by the switch predicate).
Expand All @@ -462,7 +466,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
return Status::OK();
}

Status ConstantFolding::FoldNode(NodeDef* node) {
Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
if (IsMerge(*node)) {
// Merge nodes are special, in the sense that they execute as soon as one of
// their inputs is ready. We can therefore fold a merge node iff it has at
Expand Down Expand Up @@ -511,15 +515,15 @@ Status ConstantFolding::FoldNode(NodeDef* node) {
"already present in the graph"));
}

NodeDef* const_out = added_graph_.add_node();
NodeDef* const_out = output_graph->add_node();
*const_out = *input_node;
const_out->set_name(const_out_name);
const_out->set_device(node->device());
*const_out->add_input() = AsControlDependency(*node);
node_map_->AddNode(const_out->name(), const_out);
node_map_->AddOutput(node->name(), const_out->name());

NodeDef* const_index = added_graph_.add_node();
NodeDef* const_index = output_graph->add_node();
const_index->set_op("Const");
Tensor index(DT_INT32, TensorShape({}));
index.flat<int32>()(0) = input_index;
Expand Down Expand Up @@ -608,7 +612,7 @@ Status ConstantFolding::FoldNode(NodeDef* node) {
return errors::AlreadyExists(strings::StrCat(
const_node->name(), "already present in the graph"));
}
NodeDef* added_node = added_graph_.add_node();
NodeDef* added_node = output_graph->add_node();
*added_node = *const_node;
added_node->set_device(node->device());
node_map_->AddNode(added_node->name(), added_node);
Expand Down Expand Up @@ -679,7 +683,7 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
}
// We need to record a copy of output nodes before FoldNode() modifies it.
std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
Status s = FoldNode(node);
Status s = FoldNode(node, output);
processed_nodes.insert(node->name());
if (!s.ok()) {
VLOG(1) << "Failed to fold node " << node->name() << ": " << s;
Expand All @@ -692,14 +696,19 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
}
}

// Build the graph after constant folding.
for (const auto& node : added_graph_.node()) {
// Delete the newly created nodes that don't feed anything.
int last = output->node_size() - 1;
for (int i = output->node_size() - 1; i >= 0; --i) {
const NodeDef& node = output->node(i);
auto outputs = node_map_->GetOutputs(node.name());
if (!outputs.empty()) {
auto added_node = output->add_node();
*added_node = node;
if (outputs.empty()) {
output->mutable_node()->SwapElements(i, last);
last--;
}
}
output->mutable_node()->DeleteSubrange(last + 1,
output->node_size() - last - 1);

for (const auto& node : graph_.node()) {
// If no fetch nodes are provided, we conservatively
// keep all nodes in the original graph in case users need to fetch
Expand Down Expand Up @@ -843,11 +852,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
return Status::OK();
}

Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
graph_ = item.graph;
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output) {
node_map_.reset(new NodeMap(&graph_));
nodes_to_preserve_ = item.NodesToPreserve();
nodes_whitelist_.clear();
// Fold a fetch node iff it has a single fanout. Note that if a fetch node
// has a single fanout, it would be rewritten as a constant with the same
// node name, and therefore users are still able to fetch it. This is not
Expand All @@ -861,16 +870,9 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
nodes_whitelist_.insert(fetch_node->name());
}
}
*output = GraphDef();
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
cpu_device_ = owned_device_.get();
}

bool has_feed = !item.feed.empty();
has_fetch_ = !item.fetch.empty();

GraphProperties properties(item);
bool has_feed = !item.feed.empty();
if (!has_feed) {
// Only use static shape information when there is no feed in the
// graph. That's because it's possible to feed a placeholder with a tensor
Expand All @@ -889,6 +891,28 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!has_feed) {
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
}
return Status::OK();
}

Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
nodes_to_preserve_ = item.NodesToPreserve();

if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
cpu_device_ = owned_device_.get();
}

has_fetch_ = !item.fetch.empty();

GrapplerItem item_to_optimize = item;
*output = item.graph;
do {
graph_.Swap(output);
item_to_optimize.graph = graph_;
*output = GraphDef();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
} while (output->node_size() < graph_.node_size());

*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/grappler/optimizers/constant_folding.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ConstantFolding : public GraphOptimizer {
Status EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs);

Status FoldNode(NodeDef* node);
Status FoldNode(NodeDef* node, GraphDef* output_graph);

Status FoldGraph(GraphDef* output);

Expand All @@ -69,6 +69,9 @@ class ConstantFolding : public GraphOptimizer {
const GraphProperties& properties) const;
Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);

Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);

// Points to an externally provided device or to owned_device_;
DeviceBase* cpu_device_;
std::unique_ptr<DeviceBase> owned_device_;
Expand All @@ -78,7 +81,6 @@ class ConstantFolding : public GraphOptimizer {
std::unique_ptr<NodeMap> node_map_;
std::unordered_set<string> nodes_to_preserve_;
std::unordered_set<string> nodes_whitelist_;
GraphDef added_graph_;
bool has_fetch_;
};

Expand Down
13 changes: 5 additions & 8 deletions tensorflow/java/src/gen/gen_ops.bzl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# -*- Python -*-

load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_copts")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static")
load("//tensorflow:tensorflow.bzl",
"tf_binary_additional_srcs",
"tf_cc_binary",
"tf_copts")

# Given a list of "ops_libs" (a list of files in the core/ops directory
# without their .cc extensions), generate Java wrapper code for all operations
Expand Down Expand Up @@ -54,10 +54,7 @@ def tf_java_op_gen_srcjar(name,
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
gen_tools += if_static(
extra_deps=[],
otherwise=["//tensorflow:libtensorflow_framework.so"]
)
gen_tools += tf_binary_additional_srcs()
native.genrule(
name=name,
outs=[gen_srcjar],
Expand Down
Loading