diff --git a/.gitignore b/.gitignore index 8d7808ba01e7b6..90324058600bee 100644 --- a/.gitignore +++ b/.gitignore @@ -24,10 +24,10 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata -/tensorflow/contrib/lite/tools/make/downloads/** -/tensorflow/contrib/lite/gen/** -/tensorflow/contrib/lite/examples/ios/simple/data/*.txt -/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite +/tensorflow/lite/tools/make/downloads/** +/tensorflow/lite/gen/** +/tensorflow/lite/examples/ios/simple/data/*.txt +/tensorflow/lite/examples/ios/simple/data/*.tflite xcuserdata/** /api_init_files_list.txt /estimator_api_init_files_list.txt diff --git a/BUILD b/BUILD index 4bf647e47aa56c..1200cf5f7103ca 100644 --- a/BUILD +++ b/BUILD @@ -2,5 +2,7 @@ exports_files( [ "LICENSE", "ACKNOWLEDGEMENTS", + "configure", + "configure.py", ], ) diff --git a/RELEASE.md b/RELEASE.md index 2b00d06580d925..b13b071bd6cf4d 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -258,8 +258,8 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A * Update `tf.keras` to the Keras 2.1.6 API. * Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). * Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). -* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite) - for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md) +* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite) + for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md) has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again included in the standard `pip` installation. * Improved data-loading and text processing with: @@ -562,7 +562,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 ## Major Features And Improvements * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) preview version is now available. -* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite) +* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/lite) dev preview is now available. * CUDA 9.0 and cuDNN 7 support. * Accelerated Linear Algebra (XLA): @@ -909,7 +909,7 @@ See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/ * Adds tf.contrib.nn.rank_sampled_softmax_loss, a sampled-softmax variant that can improve rank loss. * `tf.contrib.metrics`.{streaming_covariance,streaming_pearson_correlation} modified to return nan when they have seen less or equal to 1 unit of weight. * Adds time series models to contrib. See contrib/timeseries/README.md for details. -* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs +* Adds FULLY_CONNECTED Op to tensorflow/lite/schema.fbs ## Known Issues * Tensorflow_gpu compilation fails with Bazel 0.5.3. diff --git a/configure.py b/configure.py index 42aab032d29dc8..2eeeceb3399c79 100644 --- a/configure.py +++ b/configure.py @@ -1418,11 +1418,16 @@ def set_mpi_home(environ_cp): def valid_mpi_path(mpi_home): exists = ( os.path.exists(os.path.join(mpi_home, 'include')) and - os.path.exists(os.path.join(mpi_home, 'lib'))) + (os.path.exists(os.path.join(mpi_home, 'lib')) or + os.path.exists(os.path.join(mpi_home, 'lib64')) or + os.path.exists(os.path.join(mpi_home, 'lib32')))) if not exists: - print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % - (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')))) + print( + 'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found' + % (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')), + os.path.exists(os.path.join(mpi_home, 'lib64')), + os.path.exists(os.path.join(mpi_home, 'lib32')))) return exists _ = prompt_loop_or_load_from_env( @@ -1463,8 +1468,17 @@ def set_other_mpi_vars(environ_cp): if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')): symlink_force( os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so') + elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so') + elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so') + else: - raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) + raise ValueError( + 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % + mpi_home, mpi_home, mpi_home) def set_system_libs_flag(environ_cp): @@ -1681,4 +1695,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 182aa211766b70..0d497568385052 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -55,4 +55,10 @@ # does not have 'python', 'core' directories. Then, it will be copied # to tensorflow/ which does have these two directories. pass +# Similarly for compiler. Do it separately to make sure we do this even if the +# others don't exist. +try: + del compiler +except NameError: + pass # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 797cda991fde7c..65bdb6cb1b5e6f 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -63,4 +63,10 @@ # does not have 'python', 'core' directories. Then, it will be copied # to tensorflow/ which does have these two directories. pass +# Similarly for compiler. Do it separately to make sure we do this even if the +# others don't exist. +try: + del compiler +except NameError: + pass # pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 33f634a8ec5846..dbe8dba2924a7a 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -199,6 +199,7 @@ tf_cuda_cc_test( size = "small", srcs = ["c_api_test.cc"], data = [ + ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], kernels = [":test_op_kernel"], @@ -283,8 +284,8 @@ tf_cc_test( ) tf_custom_op_library( - name = "test_op.so", - srcs = ["test_op.cc"], + name = "test_op1.so", + srcs = ["test_op1.cc"], ) tf_kernel_library( diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index bff0313da4597a..fabe2fa0f60bc8 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -8775,3 +8775,28 @@ void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder, tensorflow::DeviceType(device_type), builder->BuildNodeDef(), /* def = */ nullptr, /* kernel_class_name = */ nullptr); } + +const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index, + TF_Status* status) { + const tensorflow::OpDef* op_def = nullptr; + status->status = + tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def); + if (!status->status.ok()) return nullptr; + + if (input_index >= op_def->input_arg_size() || input_index < 0) { + status->status = tensorflow::errors::InvalidArgument( + input_index, " out of range for ", op_name); + return nullptr; + } + + const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index]; + + if (input_arg.number_attr().empty()) { + status->status = tensorflow::errors::NotFound( + op_name, " does not have number_attr() defined."); + return nullptr; + } + + // The returned string is owned by OpRegistry, so liveness is not a concern. + return input_arg.number_attr().c_str(); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index b4f8635e6703a5..6639b0be72bdf8 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -202,6 +202,13 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice( TF_AttrBuilder* builder, const char* device_type, TF_Status* status); +// For argument number input_index, fetch the corresponding number_attr that +// needs to be updated with the argument length of the input list. +// Returns nullptr if there is any problem like op_name is not found, or the +// argument does not support this attribute type. +TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput( + const char* op_name, int input_index, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index c4746b4990bc3b..d5934a10395ae0 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - // Load the library. - TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - TF_Code code = TF_GetCode(status); - string status_msg(TF_Message(status)); - TF_DeleteStatus(status); - ASSERT_EQ(TF_OK, code) << status_msg; +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + // Load the library. + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op1.so", status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + // Test op list. + TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi1", op_list.op(0).name()); + TF_DeleteLibraryHandle(lib); + } +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) { TF_Buffer* op_list_buffer = TF_GetAllOpList(); tensorflow::OpList op_list; @@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) { EXPECT_TRUE(found); TF_DeleteBuffer(op_list_buffer); } - -#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) - { - // Test op list. - TF_Buffer op_list_buf = TF_GetOpList(lib); - tensorflow::OpList op_list; - EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); - ASSERT_EQ(op_list.op_size(), 1); - EXPECT_EQ("TestCApi", op_list.op(0).name()); - } -#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) - - TF_DeleteLibraryHandle(lib); } void TestEncodeDecode(int line, const std::vector& data) { @@ -2349,14 +2347,8 @@ TEST(TestApiDef, TestCreateApiDef) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TF_Buffer* op_list_buf = TF_GetAllOpList(); - status = TF_NewStatus(); + TF_Status* status = TF_NewStatus(); auto* api_def_map = TF_NewApiDefMap(op_list_buf, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -2376,7 +2368,6 @@ TEST(TestApiDef, TestCreateApiDef) { TF_DeleteBuffer(api_def_buf); TF_DeleteApiDefMap(api_def_map); TF_DeleteBuffer(op_list_buf); - TF_DeleteLibraryHandle(lib); } TEST(TestApiDef, TestCreateApiDefWithOverwrites) { @@ -2384,14 +2375,8 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TF_Buffer* op_list_buf = TF_GetAllOpList(); - status = TF_NewStatus(); + TF_Status* status = TF_NewStatus(); auto* api_def_map = TF_NewApiDefMap(op_list_buf, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -2422,7 +2407,6 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { TF_DeleteBuffer(api_def_buf); TF_DeleteApiDefMap(api_def_map); TF_DeleteBuffer(op_list_buf); - TF_DeleteLibraryHandle(lib); } class DummyKernel : public tensorflow::OpKernel { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3ee31a6a7ac641..ba3d8533db7623 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -69,7 +69,7 @@ tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], visibility = [ - "//learning/deepmind/courier:__pkg__", + "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 104d52430cf7aa..fa1b22e3af487b 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -79,10 +79,6 @@ struct TFE_TensorHandle { tensorflow::Device* op_device) : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} - TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - tensorflow::EagerContext* ctx) - : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} - TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/c/test_op1.cc similarity index 60% rename from tensorflow/stream_executor/lib/stringpiece.h rename to tensorflow/c/test_op1.cc index 76249101298588..b22cc9aef2b344 100644 --- a/tensorflow/stream_executor/lib/stringpiece.h +++ b/tensorflow/c/test_op1.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_ -#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" -#include "absl/strings/string_view.h" +namespace tensorflow { -namespace stream_executor { -namespace port { +REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc"); -using StringPiece = absl::string_view; - -} // namespace port -} // namespace stream_executor - -#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_ +} // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index c18b07603ae384..83353b79f722f0 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -170,6 +170,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc( ":array_ops", ":const_op", ":math_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", ], ) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index b49f360875c522..f98ba48735451c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -190,11 +190,13 @@ cc_library( "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:stack", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", ], ) @@ -240,6 +242,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) @@ -499,6 +502,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 054f31ba3352b2..93637a69d5d7b6 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) { return errors::Internal("Could not find compilation device ", device_type.type()); } - *result = registration->requires_compilation; + *result = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; return Status::OK(); } diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index e29da8500f9ce0..0562838f628c66 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -525,7 +525,6 @@ Predicate* PredicateFactory::MakeAndOrImpl( op->GetOperands().begin(), op->GetOperands().end()); } else { - std::vector sub_ops_intersection; common_inner_operands.clear(); absl::c_copy_if(op->GetOperands(), std::back_inserter(common_inner_operands), diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 617e31488c7dae..8a73101c184e61 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output loop_cond = ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), latch.output_true, increment_by); Output next_iteration = @@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue( value, frame_name); ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output next_iteration = ops::NextIteration( root.WithOpName(prefix + "/next_iteration"), latch.output_true); CHECK(root.graph() diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index da030b3bcc7aac..f478832781cb1d 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1122,8 +1122,11 @@ Status Encapsulator::Subgraph::BuildFunctionDef( fdef); } - if (!reuse_existing_functions || library->Find(name) == nullptr) { + const FunctionDef* original_fdef = library->Find(name); + if (!reuse_existing_functions || original_fdef == nullptr) { TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } else if (!FunctionDefsEqual(*original_fdef, fdef)) { + TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); } return Status::OK(); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 363003de5073ee..70b019d35fc80c 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -51,6 +51,12 @@ xla::StatusOr AddHostComputeKeyPlaceholder( return n; } +// Returns if the node is a XLA computation key placeholder. +bool IsKeyPlaceholderNode(const Node& n) { + return n.type_string() == "Placeholder" && + absl::EndsWith(n.name(), "_key_placeholder"); +} + // Returns nodes with given type. std::vector GatherNodesWithType(const Graph& g, const string& type) { std::vector result; @@ -107,6 +113,8 @@ xla::StatusOr BuildRecvAtHostNode( xla::StatusOr ReplaceArgNodesWithRecvAtHostNode( Graph* g, const string& oc_cluster_name, std::vector* recv_at_host_dtypes, Node* key_placeholder) { + // TODO(b/77601805): use out nodes for source node, instead of traversing all + // nodes. std::vector arg_nodes = GatherNodesWithType(*g, "_Arg"); TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes)); TF_ASSIGN_OR_RETURN( @@ -218,6 +226,8 @@ xla::StatusOr BuildSendFromHostNode( xla::StatusOr ReplaceRetNodesWithSendFromHostNode( Graph* g, const string& oc_cluster_name, std::vector* send_from_host_dtypes, Node* key_placeholder) { + // TODO(b/77601805): use in nodes for sink node, instead of traversing all + // nodes. std::vector ret_nodes = GatherNodesWithType(*g, "_Retval"); TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes)); TF_ASSIGN_OR_RETURN( @@ -258,7 +268,7 @@ absl::optional> GetInferredInputShapes( return absl::nullopt; } - const PartialTensorShape shape = shapes[e->dst_input()]; + const PartialTensorShape shape = shapes[e->src_output()]; if (!shape.IsFullyDefined()) { return absl::nullopt; } @@ -411,8 +421,7 @@ Status ConstructHostGraph( if (node_map.find(n) != node_map.end()) { // Already copied this node. copy = node_map.at(n); - } else if (n->type_string() == "Placeholder" && - absl::EndsWith(n->name(), "_key_placeholder")) { + } else if (IsKeyPlaceholderNode(*n)) { // Change a). copy = key_placeholder; node_map[n] = copy; @@ -691,8 +700,7 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( // Step 4: add XLA cluster and outside compilation attr. for (Node* n : (*graph)->nodes()) { - if (n->type_string() == "Placeholder" && - absl::EndsWith(n->name(), "_key_placeholder")) { + if (IsKeyPlaceholderNode(*n)) { continue; } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index f7ffa2589ae676..bd8719b7f1acb7 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -221,8 +221,8 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .WithOpName("static_shaped_slice"), slice_inputs_int64.input, slice_inputs_int64.begin, slice_size) .node(); - std::vector compile_time_const_inputs; - compile_time_const_inputs.push_back(2); + std::vector compile_time_const_inputs; + compile_time_const_inputs.push_back("size"); (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, compile_time_const_inputs); return status; @@ -314,15 +314,18 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) { Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", + **options.graph, options.flib_def); + } + bool changed; TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed)); - if (changed) { - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); - if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass", - **options.graph, options.flib_def); - } + if (changed && flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass", + **options.graph, options.flib_def); } return Status::OK(); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 06cd7cf2dd7a30..0f6f612e967035 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -129,8 +129,8 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) { Op("ConcatV2"), AssignedDevice(kHostName), Inputs(m_slice_size_0, Const(static_cast(500)), Const(zero_32)))); - std::vector compile_time_constant_inputs; - compile_time_constant_inputs.push_back(2); + std::vector compile_time_constant_inputs; + compile_time_constant_inputs.push_back("size"); auto m_dynamic_slice = NodeWith( Op("Slice"), AssignedDevice(kDeviceName), Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs), diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 6bcae1dcc3dcf8..56b7909ffd3348 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -39,12 +39,22 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + namespace tensorflow { namespace { -Status PlatformInfoFromContext(OpKernelConstruction* ctx, - XlaPlatformInfo* result) { +XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; @@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx, } if (!device_allocator) { - TF_ASSIGN_OR_RETURN(se::Platform* const platform, - se::MultiPlatformManager::PlatformWithId(platform_id)); + xla::StatusOr maybe_platform = + se::MultiPlatformManager::PlatformWithId(platform_id); + OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); + xla_allocator = absl::make_unique( - platform, ctx->device()->GetAllocator({})); + maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); } - *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - std::move(xla_allocator), device_allocator); - - return Status::OK(); + return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); } // A closure describing how to run a compiled version of a TensorFlow function. @@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, : OpKernel(ctx), constants_(constants), resources_(resources), - function_(function) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} + function_(function), + platform_info_(PlatformInfoFromContext(ctx)) {} static Status BuildCompilationCache(OpKernelContext* ctx, const XlaPlatformInfo& platform_info, @@ -333,18 +342,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - // Helper static functions to construct parameters for // XlaLocalLaunchBase constructor from OpKernelConstruction. std::vector ConstantsVector(OpKernelConstruction* ctx) { @@ -381,7 +378,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) { return *func; } -#undef OP_REQUIRES_OK_RETURN +bool MustCompileAttr(OpKernelConstruction* ctx) { + bool must_compile; + OP_REQUIRES_OK_RETURN(ctx, false, + ctx->GetAttr("must_compile", &must_compile)); + return must_compile; +} } // namespace XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) @@ -396,10 +398,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx), constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), - function_(FunctionAttr(ctx)) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); -} + function_(FunctionAttr(ctx)), + platform_info_(PlatformInfoFromContext(ctx)), + must_compile_(MustCompileAttr(ctx)) {} void XlaCompileOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaCompileOp " << def().name() @@ -409,13 +410,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) { + bool cannot_compile_cluster; + { + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster = cannot_compile_cluster_; + } + + if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + cannot_compile_cluster) { executable = nullptr; } else { - OP_REQUIRES_OK(ctx, CompileToLocalExecutable( - ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/!must_compile_, &client, - &variables, &kernel, &executable)); + Status status = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, + /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); + if (must_compile_ || status.code() != error::UNIMPLEMENTED) { + OP_REQUIRES_OK(ctx, status); + } + + if (status.code() == error::UNIMPLEMENTED) { + LOG(WARNING) << "Compilation failed:" << status.ToString() + << ". Falling back to TF function call."; + executable = nullptr; + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster_ = true; + } } AllocatorAttributes host_alloc_attrs; @@ -452,9 +470,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(1, compilation_successful); } -XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index ac90837e0d9094..7b4d4b5b473778 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#include + #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" @@ -33,6 +35,7 @@ namespace tensorflow { class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; explicit XlaPlatformInfo(const DeviceType device_type, se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, @@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel { protected: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; - XlaPlatformInfo platform_info_; + const NameAttrList function_; + const XlaPlatformInfo platform_info_; }; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph @@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel { private: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; + const NameAttrList function_; XlaPlatformInfo platform_info_; - bool must_compile_; + const bool must_compile_; + + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented + // error when compiling the cluster this _XlaCompile is supposed to compile. + // If `cannot_compile_cluster_` is true then we avoid compiling this cluster + // on any future calls to _XlaCompile. + bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false; + + mutex cannot_compile_cluster_mu_; }; class XlaRunOp : public OpKernel { @@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - XlaPlatformInfo platform_info_; + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 0da8be046f5dfe..dae6ca4ad2403d 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -49,6 +49,25 @@ limitations under the License. namespace tensorflow { namespace { +// Aggregates information about what kinds of ops are allowed. +struct OperationFilter { + // Whether resource variable ops are allowed. We do not allow resource + // variable ops in called functions (either as direct TF calls or as higher + // order control flow ops) because we do not yet model their memory effects in + // jit/resource_variable_safety_analysis. + bool allow_resource_ops; + + // Whether stateful RNG ops are allowed. XLA's RNG does not have the same + // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid + // auto-clustering stateful RNG ops. + bool allow_stateful_rng_ops; +}; + +bool IsStatefulRandomOp(absl::string_view op_name) { + return op_name == "RandomUniform" || op_name == "RandomShuffle" || + op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || + op_name == "TruncatedNormal"; +} bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient @@ -101,7 +120,7 @@ const int kMaxRecursionDepth = 10; bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, - bool allow_resource_ops, int depth, + const OperationFilter& op_filter, int depth, FunctionLibraryRuntime* lib_runtime); // Tests whether 'while_node' is a completely compilable loop. @@ -109,7 +128,7 @@ bool IsCompilableCall(const NodeDef& call_def, // while loop to be compilable. bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, - bool allow_resource_ops, int depth, + const OperationFilter& op_filter, int depth, FunctionLibraryRuntime* lib_runtime) { const NameAttrList* name_attr; NodeDef call; @@ -124,7 +143,7 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_cond"); call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1, lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop condition: " << cond_func; @@ -140,7 +159,7 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_body"); call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1, lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop body: " << body_func; @@ -154,7 +173,7 @@ bool IsCompilableWhile(const Node& while_node, // compilable. bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, - bool allow_resource_ops, int depth, + const OperationFilter& op_filter, int depth, FunctionLibraryRuntime* lib_runtime) { if (depth > kMaxRecursionDepth) { VLOG(2) << "Rejecting " << call_def.op() @@ -195,16 +214,20 @@ bool IsCompilableCall(const NodeDef& call_def, continue; if (node->type_string() == "While") { // Handle functional While loop. - return IsCompilableWhile(*node, jit_device_type, allow_resource_ops, - depth + 1, lib_runtime); + return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1, + lib_runtime); } - if (!allow_resource_ops && + if (!op_filter.allow_resource_ops && (HasResourceInput(*node) || HasResourceOutput(*node))) { return false; } + if (!op_filter.allow_stateful_rng_ops && + IsStatefulRandomOp(node->type_string())) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops, - depth + 1, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " << node->name() << ": " << node->def().ShortDebugString(); return false; @@ -426,14 +449,28 @@ Status FindCompilationCandidates( CHECK( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); + + OperationFilter op_filter; + op_filter.allow_resource_ops = registration->compile_resource_ops; + op_filter.allow_stateful_rng_ops = + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); + if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, - registration->compile_resource_ops, 0, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, + lib_runtime)) { VLOG(2) << "Rejecting " << node->name() << ": unsupported op " << node->type_string(); continue; } - if (!registration->compile_resource_ops && + + if (!op_filter.allow_stateful_rng_ops && + IsStatefulRandomOp(node->type_string())) { + VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; + continue; + } + + if (!op_filter.allow_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { // We don't have a way of returning values of type DT_RESOURCE from XLA // computations so we avoid auto-clustering nodes producing DT_RESOURCE. @@ -444,6 +481,7 @@ Status FindCompilationCandidates( << node->type_string(); continue; } + if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( @@ -501,9 +539,7 @@ Status FindCompilationCandidates( // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not // for CPU/GPU. if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, - registration->compile_resource_ops, 0, - lib_runtime)) { + !IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) { continue; } // _Arg nodes in a top-level function represent feeds. @@ -563,10 +599,12 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - // We can always *compile* resource operations, even if we are sometimes - // unable to auto-cluster them. - const bool compile_resource_ops = true; - return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr); + // We can always *compile* resource operations and stateful RNGs, even if we + // are sometimes unable to auto-cluster them. + OperationFilter op_filter; + op_filter.allow_resource_ops = true; + op_filter.allow_stateful_rng_ops = true; + return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } Status MarkForCompilationPass::Run( @@ -577,10 +615,8 @@ Status MarkForCompilationPass::Run( GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); - bool cpu_global_jit = flags->tf_xla_cpu_global_jit; bool fusion_only = flags->tf_xla_fusion_only; - VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; @@ -599,9 +635,6 @@ Status MarkForCompilationPass::Run( return false; } - // If this device requires a JIT, we must say yes. - if (registration->requires_compilation) return true; - // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); @@ -638,18 +671,21 @@ Status MarkForCompilationPass::Run( return false; } - // Otherwise use the value of global_jit_level. - // Ignore enable_jit_by_default if global jit compilation for CPU - // is explicitly requested via tf_xla_cpu_global_jit flag - bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + // Otherwise use the value of global_jit_level and the device's + // autoclustering policy. bool should_compile = - (ignore_registration || registration->enable_jit_by_default) && - global_jit_level != OptimizerOptions::OFF; + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); if (!should_compile) { if (global_jit_level == OptimizerOptions::OFF) { VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; } else { - VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + VLOG(2) + << "Rejecting " << node->name() + << ": autoclustering for device only when requested explicitly."; } } return should_compile; @@ -1037,12 +1073,10 @@ Status MarkForCompilationPass::RunImpl( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if the operator is placed on a device that requires - // compilation, or if it contains at least one op that is marked for + // Also, always compile if it contains at least one op that is marked for // compilation that is not an Identity op. if (effective_cluster_sizes[cluster] >= min_cluster_size || - (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || - registration->requires_compilation) { + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) { string& name = cluster_names[cluster]; if (name.empty()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 452091b28227ff..ef4f1ea2b06461 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -923,9 +923,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); - EXPECT_NE(clusters["test/shape_rng"], ""); - EXPECT_NE(clusters["test/reshape"], ""); - EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); + EXPECT_EQ(clusters["test/shape_rng"], ""); + EXPECT_EQ(clusters["test/reshape"], ""); } TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { @@ -1061,5 +1060,48 @@ TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) { // Improved Heuristics should prevent this probably. EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]); } + +TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { + absl::string_view xla_cpu_device = + "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); + Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); + Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT); + Output c = ops::Add(root.WithOpName("test/c"), a, b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_cpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/a"], ""); + EXPECT_NE(clusters["test/b"], ""); + EXPECT_NE(clusters["test/c"], ""); +} + +TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); + Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); + Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT); + Output c = ops::Add(root.WithOpName("test/c"), a, b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/a"], ""); + EXPECT_EQ(clusters["test/b"], ""); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index 95b8ace306aca8..c788091724e443 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -485,6 +485,16 @@ std::pair impl::AttrLiteralHelper( return {int_list_attr.first, attr_value}; } +std::pair impl::AttrLiteralHelper( + const std::pair>& string_list_attr) { + AttrValue attr_value; + AttrValue::ListValue* list = attr_value.mutable_list(); + for (string s : string_list_attr.second) { + list->add_s(s); + } + return {string_list_attr.first, attr_value}; +} + impl::NodeMatcherProperties impl::Attr(std::pair attr) { impl::NodeMatcherProperties props; props.set_attr(std::move(attr)); diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h index cd2ab53e4047e6..0d4f02c236bba3 100644 --- a/tensorflow/compiler/jit/node_matchers.h +++ b/tensorflow/compiler/jit/node_matchers.h @@ -170,6 +170,9 @@ std::pair AttrLiteralHelper( std::pair AttrLiteralHelper( const std::pair>& int_list_attr); + +std::pair AttrLiteralHelper( + const std::pair>& string_list_attr); } // namespace impl // ----------------------------------------------------------------------------- diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 5b9610322336ac..550ffa2465a2e1 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -210,7 +210,8 @@ bool IsIntraClusterEdge(const Edge& edge) { bool IsMustCompileDevice(const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - return registration->requires_compilation; + return registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; } return false; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 6d1e5279bfa7ea..31cb32e3059bc1 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -173,9 +173,7 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::Options options; options.device_type = metadata.jit_device_type(); options.client = metadata.client(); - auto flib_def = absl::make_unique( - OpRegistry::Global(), FunctionDefLibrary{}); - options.flib_def = flib_def.get(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.shape_representation_fn = metadata.shape_representation_fn(); XlaCompiler::CompileOptions compile_options; diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index f2cea3d000e239..116e0756036e72 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = !compile_on_demand; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + compile_on_demand + ? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested + : XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); @@ -60,7 +62,6 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, options.device_name = DEVICE_XLA_CPU; options.device_ordinal = 0; options.compilation_device_name = DEVICE_CPU_XLA_JIT; - options.transfer_as_literal = false; options.use_multiple_streams = false; auto device = absl::make_unique(session_options, options); devices->push_back(device.release()); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 17353456eb5d18..2289abd2df3726 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -201,12 +201,18 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, jit_device_name_(options.compilation_device_name), platform_(options.platform), use_multiple_streams_(options.use_multiple_streams), - transfer_as_literal_(options.transfer_as_literal), shape_representation_fn_(options.shape_representation_fn) { VLOG(1) << "Created XLA device " << options.compilation_device_name << " " << this; thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device", /*num_threads=*/1)); + + // We have multiple device to device streams to allow for some concurrency + // between transfers. The particular value of '4' is chosen fairly + // arbitrarily. It may be necessary to make this tunable via + // XlaDevice::Options. + static constexpr int kNumDeviceToDeviceStreams = 4; + device_to_device_streams_.resize(kNumDeviceToDeviceStreams); } XlaDevice::~XlaDevice() { @@ -274,8 +280,9 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, &need_new_device_context)); - std::shared_ptr host_to_device_stream = stream_; - std::shared_ptr device_to_host_stream = stream_; + std::shared_ptr host_to_device_stream; + std::shared_ptr device_to_host_stream; + std::vector> device_to_device_streams; if (use_multiple_streams_) { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, @@ -283,8 +290,18 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", &device_to_host_stream_, &need_new_device_context)); + for (std::shared_ptr& stream : device_to_device_streams_) { + TF_RETURN_IF_ERROR( + EnsureStreamOkLocked(backend, "device_to_device_stream", &stream, + &need_new_device_context)); + } host_to_device_stream = host_to_device_stream_; device_to_host_stream = device_to_host_stream_; + device_to_device_streams = device_to_device_streams_; + } else { + host_to_device_stream = stream_; + device_to_host_stream = stream_; + device_to_device_streams = {stream_}; } if (!need_new_device_context) { @@ -302,8 +319,9 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { // ensures that the streams remain live for the duration of a run, even if // an error is encountered and the streams are replaced with new ones. device_context_ = new XlaDeviceContext( - stream_, host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); + stream_, std::move(host_to_device_stream), + std::move(device_to_host_stream), std::move(device_to_device_streams), + client(), shape_representation_fn_, thread_pool_.get()); VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " << device_context_; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 223f0f6649f054..8881b697bc863e 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -108,11 +108,6 @@ class XlaDevice : public LocalDevice { // The name of the compilation device (e.g., "XLA_CPU_JIT"); string compilation_device_name; - // 'transfer_as_literal' is true if device<->host transfers must be done - // using XLA's TransferLiteral{To,From}Device interface. If false, we can - // use ThenMemcpy instead. - bool transfer_as_literal = false; - // If 'use_multiple_streams' is true, we create separate streams for // compute, host-to-device, and device-to-host communication. bool use_multiple_streams = false; @@ -188,6 +183,7 @@ class XlaDevice : public LocalDevice { se::Platform* const platform_; // Not owned. // Memory allocator associated with this device. Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. + // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -203,9 +199,11 @@ class XlaDevice : public LocalDevice { // If use_multiple_streams_, device to host transfers are performed using this // stream. std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); - // Must we use XLA's transfer manager for correct host<->device transfers? if - // false, we can use ThenMemcpy() instead. - const bool transfer_as_literal_; + // If use_multiple_streams_, transfers between different devices are performed + // using these streams. + std::vector> device_to_device_streams_ + GUARDED_BY(mu_); + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // The device context accessed by all users of the XlaDevice, set by calls to diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index cbab81fe2f51db..eb3cf27624bb76 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -53,16 +53,17 @@ void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, std::shared_ptr host_to_device_stream, - std::shared_ptr device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, + std::shared_ptr device_to_host_stream, + std::vector> device_to_device_streams, + xla::LocalClient* client, XlaCompiler::ShapeRepresentationFn shape_representation_fn, thread::ThreadPool* thread_pool) : stream_(std::move(compute_stream)), host_to_device_stream_(std::move(host_to_device_stream)), device_to_host_stream_(std::move(device_to_host_stream)), + device_to_device_streams_(std::move(device_to_device_streams)), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal), shape_representation_fn_(std::move(shape_representation_fn)), thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); @@ -75,71 +76,6 @@ XlaDeviceContext::XlaDeviceContext( } } -Status XlaDeviceContext::TransferLiteralToDevice(const Tensor& host_tensor, - Tensor* device_tensor) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), - host_tensor.shape(), &xla_shape)); - // Create a reference to hold onto host_tensor until after the literal has - // been transferred. Also make sure the literal exists until the function - // asynchronously completes, as it will be wrapped in an xla::LiteralSlice. - TensorReference ref(host_tensor); - auto literal = std::make_shared( - static_cast(DMAHelper::base(&host_tensor)), xla_shape); - - XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); - VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " - << shaped_buffer.ToString(); - if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( - stream_->parent(), shaped_buffer)) { - // Initially wait for the compute stream so that memory allocations are - // synchronized. - host_to_device_stream_->ThenWaitFor(stream_.get()); - } - TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_.get(), *literal, shaped_buffer)); - if (UseMultipleStreams()) { - auto event = std::make_shared(stream_->parent()); - TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; - host_to_device_stream_->ThenRecordEvent(event.get()); - xla_tensor->ResetDefinitionEvent(std::move(event), - host_to_device_stream_.get()); - } - // Unref the host tensor, and capture the literal shared_ptr too so it goes - // out of scope when the lambda completes. - // We don't defer the call to done() onto the stream here, and the reasons why - // this is correct are subtle. We assume that: - // a) all consumers of the device tensor will wait for its definition event. - // b) if the tensor is destroyed, then the memory allocator will not hand out - // the same buffers until the transfer has completed. - host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); - - return Status::OK(); -} - -void XlaDeviceContext::TransferLiteralFromDevice( - Tensor* host_tensor, const Tensor& device_tensor, - const StatusCallback& done) const { - xla::MutableBorrowingLiteral literal; - TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); - - const xla::ShapedBuffer& shaped_buffer = - XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - - TensorReference ref(device_tensor); - transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_.get(), shaped_buffer, literal, - [=, &shaped_buffer](xla::Status status) { - ref.Unref(); - done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " - << shaped_buffer.ToString(); - return status; - }()); - }); -} - void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -158,54 +94,73 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << cpu_tensor->shape().DebugString() << " " << device_tensor->shape().DebugString(); - void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); - const int64 total_bytes = cpu_tensor->TotalBytes(); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); - xla::StatusOr shape_or_status = - shape_representation_fn_(device_tensor->shape(), device_tensor->dtype()); - if (!shape_or_status.ok()) { - done(shape_or_status.status()); - return; - } - TensorShape shape = shape_or_status.ValueOrDie(); - if (!xla_tensor->has_shaped_buffer()) { - Status s = + Status status = [&]() -> Status { + TF_ASSIGN_OR_RETURN(TensorShape shape, + shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype())); + + // The device tensor should always be fresh. + TF_RET_CHECK(!xla_tensor->has_shaped_buffer()); + + xla_tensor->set_host_tensor(*cpu_tensor); + TF_RETURN_IF_ERROR( xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, - stream_->parent()->device_ordinal()); - if (!s.ok()) { - done(s); - return; + stream_->parent()->device_ordinal())); + + xla::BorrowingLiteral literal( + static_cast(DMAHelper::base(cpu_tensor)), + xla_tensor->shaped_buffer().on_host_shape()); + + VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + << xla_tensor->shaped_buffer().ToString(); + if (UseMultipleStreams() && + !transfer_manager_->CanShapedBufferBeAccessedNow( + stream_->parent(), xla_tensor->shaped_buffer())) { + // Initially wait for the compute stream so that memory allocations are + // synchronized. + host_to_device_stream_->ThenWaitFor(stream_.get()); } - } - Status status; - if (transfer_as_literal_) { - Tensor reshaped_cpu_tensor; - if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { - done(errors::Internal( - "Tensor::CopyFrom failed when copying from CPU to XLA device")); - return; - } - status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - } else { - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = host_to_device_stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", - host_to_device_stream_.get(), block_status.error_message().c_str()); + TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( + host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer())); + + if (UseMultipleStreams()) { + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->ResetDefinitionEvent(std::move(event), + host_to_device_stream_.get()); } + + return Status::OK(); + }(); + if (!status.ok()) { + done(status); + return; } - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); + + // Create a reference to hold onto cpu_tensor until after the literal has + // been transferred + TensorReference ref(*cpu_tensor); + if (UseMultipleStreams()) { + // Unref the host tensor when the transfer completes. + // We don't defer the call to done() onto the stream here, and the reasons + // why this is correct are subtle. We assume that: + // a) all consumers of the device tensor will wait for its definition event. + // b) if the tensor is destroyed, then the memory allocator will not hand + // out the same buffers until the transfer has completed. + host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); }); + done(status); + } else { + host_to_device_stream_->ThenDoHostCallback([ref, done]() { + ref.Unref(); + done(Status::OK()); + }); } - done(status); } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, @@ -225,30 +180,31 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, << cpu_tensor->shape().DebugString() << " " << device_tensor->shape().DebugString(); - const int64 total_bytes = cpu_tensor->TotalBytes(); - se::DeviceMemoryBase dev_src_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); - Status status; - if (transfer_as_literal_) { - TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); - return; - } else { - device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = device_to_host_stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_.get(), - block_status.error_message().c_str()); - } - } + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal)); + + TensorReference ref(*device_tensor); + transfer_manager_->TransferLiteralFromDevice( + device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, + [ref, xla_tensor, done](xla::Status status) { + done([&]() -> Status { + VLOG(1) << "Transfer from device as literal: " + << xla_tensor->shaped_buffer().ToString(); + return status; + }()); + ref.Unref(); + }); +} - done(status); +se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() { + DCHECK_GT(device_to_device_streams_.size(), 0); + absl::MutexLock lock(&mu_); + int stream = next_stream_; + next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size(); + return device_to_device_stream(stream); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 39521ec7ad6779..1e18df197a2dd6 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -50,7 +51,8 @@ class XlaDeviceContext : public DeviceContext { std::shared_ptr compute_stream, std::shared_ptr host_to_device_stream, std::shared_ptr device_to_host_stream, - xla::LocalClient* client, bool transfer_as_literal, + std::vector> device_to_device_streams, + xla::LocalClient* client, XlaCompiler::ShapeRepresentationFn shape_representation_fn, thread::ThreadPool* thread_pool); @@ -61,14 +63,26 @@ class XlaDeviceContext : public DeviceContext { absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } + se::Stream* host_to_device_stream() const { + return host_to_device_stream_.get(); + } + se::Stream* device_to_host_stream() const { + return device_to_host_stream_.get(); + } + se::Stream* device_to_device_stream(int index) const { + return device_to_device_streams_.at(index).get(); + } + xla::TransferManager* transfer_manager() const { return transfer_manager_; } + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } + + // Returns a device-to-device stream, in round-robin fashion. + se::Stream* GetDeviceToDeviceStream(); private: - Status TransferLiteralToDevice(const Tensor& host_tensor, - Tensor* device_tensor) const; - void TransferLiteralFromDevice(Tensor* host_tensor, - const Tensor& device_tensor, - const StatusCallback& done) const; bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } // The main compute stream of the device, used to synchronize the transfer @@ -80,16 +94,22 @@ class XlaDeviceContext : public DeviceContext { // The stream to use for transferring data from device to host. Can be // idential to stream_, but must not be nullptr. std::shared_ptr device_to_host_stream_; + // Streams to use for transferring data directly between different devices, + // e.g., over NVLINK. + std::vector> device_to_device_streams_; + // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; - // True if we must use XLA's TransferManager for correct device transfers. - const bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // Thread pool used for running closures thread::ThreadPool* thread_pool_; + + absl::Mutex mu_; + int next_stream_ GUARDED_BY(mu_) = 0; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 241ea8f60df8b6..adf0f994b84d9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/kernels/stack.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .TypeConstraint("T") \ .HostMemory("input"), \ - RetvalOp); + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("StackV2") \ + .Device(DEVICE) \ + .HostMemory("max_size") \ + .HostMemory("handle"), \ + StackOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("T", TYPES), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("elem_type", TYPES), \ + StackPopOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp); -// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// TODO(b/118881356): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // and write the tensors they access in order to concatenate them into a batch. // We would need either to call out to an XLA computation to perform the diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index d9021fb001abda..717daadc4ac8e0 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, std::vector* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); @@ -59,7 +59,6 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, options.device_name = DEVICE_XLA_GPU; options.device_ordinal = 0; options.compilation_device_name = DEVICE_GPU_XLA_JIT; - options.transfer_as_literal = false; options.use_multiple_streams = false; auto device = absl::make_unique(session_options, options); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index aee3b58c997ff0..e828bae865d630 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices( XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); @@ -60,7 +60,6 @@ Status XlaInterpreterDeviceFactory::CreateDevices( options.device_name = DEVICE_XLA_INTERPRETER; options.device_ordinal = 0; options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - options.transfer_as_literal = false; options.use_multiple_streams = false; auto device = absl::make_unique(session_options, options); devices->push_back(device.release()); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f76aef2ccaa3ea..194e710f1f1134 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -375,6 +375,27 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "resampler_ops_test", + size = "small", + srcs = ["resampler_ops_test.py"], + disabled_backends = [ + # TODO(b/74459949) Support BatchDot in CPU backend. + "cpu", + "cpu_ondemand", + ], + # TODO(b/112295522): figure out how to make OSS build pass. + tags = ["no_oss"], + deps = [ + ":xla_test", + "//tensorflow/contrib/resampler:resampler_ops", + "//tensorflow/contrib/resampler:resampler_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -489,8 +510,6 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], - # Functions are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -680,9 +699,6 @@ tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], - disabled_backends = [ - "cpu_ondemand", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -698,6 +714,10 @@ tf_xla_py_test( size = "medium", srcs = ["reduce_ops_test.py"], shard_count = 5, + tags = [ + # TODO(b/119059212): Re-enable this test in OSS. + "no_oss", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -713,7 +733,6 @@ tf_xla_py_test( name = "reduce_window_test", size = "small", srcs = ["reduce_window_test.py"], - disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -822,8 +841,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - # Stack ops are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -851,7 +868,7 @@ tf_xla_py_test( size = "small", srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", + disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -872,7 +889,7 @@ tf_xla_py_test( size = "small", srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", + disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -952,7 +969,6 @@ tf_xla_py_test( name = "while_test", size = "small", srcs = ["while_test.py"], - disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -1109,6 +1125,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], @@ -1219,7 +1236,6 @@ tf_xla_py_test( name = "xla_ops_test", size = "medium", srcs = ["xla_ops_test.py"], - disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 1b39d53dc0908e..4e6dd6abfc9cdb 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -178,6 +178,13 @@ def testFloatOps(self): [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) + self._testBinary( + gen_nn_ops.leaky_relu_grad, + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), + expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10], + dtype=dtype)) + self._testBinary( gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 1d3979b21bfd91..447a7de2cb6526 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -50,6 +50,8 @@ def tf_xla_py_test( """ if disabled_backends == None: disabled_backends = [] + if type(disabled_backends) != "list": + fail("disabled_backends must be a list of strings", "disabled_backends") enabled_backends = [b for b in all_backends() if b not in disabled_backends] test_names = [] diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index f1b87a5ffb73be..5b197afd655404 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -135,7 +135,7 @@ def testFtrlwithoutRegularization(self): self.assertAllCloseAccordingToType( np.array([-2.60260963, -4.29698515]), var0.eval(), - float_rtol=1e-5, + float_rtol=1e-4, half_rtol=1e-2) self.assertAllCloseAccordingToType( np.array([-0.28432083, -0.56694895]), @@ -167,7 +167,8 @@ def testFtrlwithoutRegularization2(self): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5) + np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5, + float_rtol=1e-4) self.assertAllCloseAccordingToType( np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index d67b16f8e9e732..0e2d840418156d 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -448,8 +448,8 @@ def testAlignCorners1x2To3x2(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2]], dtype=dtype), [3, 3], - expected=np.array( - [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], + dtype=np.float32)) def testAlignCorners1x2To3x2Grad(self): for dtype in self.float_types: @@ -477,8 +477,8 @@ def testAlignCorners2x2To3x3(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], - expected=np.array( - [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + expected=np.array([[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], + dtype=np.float32)) def testAlignCorners2x2To3x3Grad(self): self._assertBackwardOpMatchesExpected( @@ -498,8 +498,8 @@ def testAlignCorners3x3To2x2Grad(self): np.array([[7, 13], [22, 4]], dtype=np.float32), input_shape=[3, 3], dtype=dtype, - expected=np.array( - [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + expected=np.array([[7, 0, 13], [0, 0, 0], [22, 0, 4]], + dtype=np.float32)) def testAlignCorners4x4To3x3(self): for dtype in self.float_types: @@ -507,8 +507,8 @@ def testAlignCorners4x4To3x3(self): np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=dtype), [3, 3], - expected=np.array( - [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + expected=np.array([[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], + dtype=np.float32)) def testAlignCorners4x4To3x3Grad(self): for dtype in self.float_types: @@ -516,41 +516,39 @@ def testAlignCorners4x4To3x3Grad(self): np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), input_shape=[4, 4], dtype=dtype, - expected=np.array( - [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], - [7, 4, 4, 9]], - dtype=np.float32)) + expected=np.array([[1, 1, 1, 3], [2, 1.25, 1.25, 3], + [2, 1.25, 1.25, 3], [7, 4, 4, 9]], + dtype=np.float32)) def testAlignCorners3x3To9x9(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [9, 9], expected=np.array( - [[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ - 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 - ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ - 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 - ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ - 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 - ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ - 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 - ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + [[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], + [1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75], + [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], + [3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25], + [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], + [4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75], + [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], + [6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25], + [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], dtype=np.float32)) def testAlignCorners3x3To9x9Grad(self): for dtype in self.float_types: self._assertBackwardOpMatchesExpected( - np.array( - [[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ - 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 - ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ - 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 - ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ - 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 - ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ - 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 - ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], - dtype=np.float32), + np.array([[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], + [1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75], + [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], + [3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25], + [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], + [4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75], + [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], + [6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25], + [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + dtype=np.float32), input_shape=[3, 3], dtype=dtype, expected=np.array( @@ -571,12 +569,12 @@ def testAlignCorners8x8To16x16(self): (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, [16, 16], - expected=7 * (np.array( - [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], - dtype=np.float32) + np.array( - [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], - [12], [13], [14], [15]], - dtype=np.float32)), + expected=7 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + + np.array([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), large_tolerance=True) def testNonAlignCorners3x2To6x4(self): @@ -600,6 +598,26 @@ def testNonAlignCorners6x4To3x2(self): expected=np.array(expected_data, dtype=dtype), align_corners=False) + def testNonAlignCorners3x2To6x4Batch2(self): + input_data = [[[64, 32], [32, 64], [50, 100]], [[32, 16], [16, 32], + [25, 50]]] + expected_data = [[[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0], + [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0], + [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]], + [[32.0, 24.0, 16.0, 16.0], [24.0, 24.0, 24.0, 24.0], + [16.0, 24.0, 32.0, 32.0], [20.5, 30.75, 41.0, 41.0], + [25.0, 37.5, 50.0, 50.0], [25.0, 37.5, 50.0, 50.0]]] + + for dtype in self.float_types: + input_image = np.array(input_data, dtype=dtype) + expected = np.array(expected_data, dtype=dtype) + with self.cached_session() as sess, self.test_scope(): + image = array_ops.placeholder(input_image.dtype) + resized = gen_image_ops.resize_bilinear( + image, [6, 4], align_corners=False) + out = sess.run(resized, {image: input_image[:, :, :, np.newaxis]}) + self.assertAllClose(expected[:, :, :, np.newaxis], out) + class NonMaxSuppressionTest(xla_test.XLATestCase): @@ -804,5 +822,6 @@ def testSelectFromContinuousOverLap(self): self.assertEqual(num_valid, 3) self.assertAllClose(indices_tf[:num_valid], [0, 2, 4]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index dc119fb0f8a41a..cfccf5f3d2a0a3 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -2687,6 +2688,37 @@ TEST_F(OpTest, Reverse) { }); } +TEST_F(OpTest, ReverseSequence) { + Repeatedly([this]() { + std::vector dims = RandomDims(/*min_rank=*/2); + auto type = Choose(kAllXlaTypes); + int64 rank = dims.size(); + + // Choose random batch and sequence dimensions. + std::vector shuffled_dim_ids(rank); + absl::c_iota(shuffled_dim_ids, 0); + absl::c_shuffle(shuffled_dim_ids, generator()); + shuffled_dim_ids.resize(2); + int batch_dim = shuffled_dim_ids[0]; + int seq_dim = shuffled_dim_ids[1]; + + int batch_size = dims[batch_dim]; + int max_seq_len = dims[seq_dim]; + std::vector seq_lens(batch_size); + std::uniform_int_distribution d(0, max_seq_len); + absl::c_generate(seq_lens, [&]() { return d(generator()); }); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ReverseSequence") + .RandomInput(type, dims) + .Input(test::AsTensor(seq_lens)) + .Attr("seq_dim", seq_dim) + .Attr("batch_dim", batch_dim) + .Attr("T", type) + .Attr("Tlen", DT_INT32)); + }); +} + TEST_F(OpTest, ReverseV2) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/compiler/tests/resampler_ops_test.py new file mode 100644 index 00000000000000..d05554fdb681a7 --- /dev/null +++ b/tensorflow/compiler/tests/resampler_ops_test.py @@ -0,0 +1,156 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for resampler ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib import resampler +from tensorflow.contrib.resampler.ops import gen_resampler_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ResamplerOpsTest(xla_test.XLATestCase): + + def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected): + with self.test_session() as sess, self.test_scope(): + input_image = array_ops.placeholder(image_np.dtype) + warp = array_ops.placeholder(warp_np.dtype) + resampled = resampler.resampler(input_image, warp, name='resampler') + out = sess.run(resampled, {input_image: image_np, warp: warp_np}) + + self.assertAllCloseAccordingToType( + expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2) + + def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, + expected_grad_data, expected_grad_warp): + with self.cached_session() as sess, self.test_scope(): + input_image = array_ops.placeholder(input_np.dtype) + warp = array_ops.placeholder(warp_np.dtype) + grad_output = array_ops.placeholder(grad_output_np.dtype) + + grad_data, grad_warp = gen_resampler_ops.resampler_grad( + input_image, warp, grad_output) + + grad_data_tf, grad_warp_tf = sess.run([grad_data, grad_warp], { + input_image: input_np, + warp: warp_np, + grad_output: grad_output_np + }) + + self.assertAllCloseAccordingToType( + expected_grad_warp, grad_warp_tf, half_rtol=1e-2, bfloat16_rtol=3e-2) + self.assertAllCloseAccordingToType( + expected_grad_data, grad_data_tf, half_rtol=1e-2, bfloat16_rtol=3e-2) + + def testSimple(self): + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_rgb_data = [0, 5, 13, 54] + input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2] + warp_data = [0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[26.42]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + grad_output = np.ones([1, 1], dtype=dtype) + + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + + expected_grad_warp = [[26.60000038, 38.20000076]] + + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + + def testMultiChannel(self): + for dtype in self.float_types: + input_shape = [1, 2, 2, 3] + input_rgb_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2] + warp_data = [0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[59.58000183, 146.94000244, 107.37999725]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + grad_output = np.ones([1, 3], dtype=dtype) + + expected_grad_data = [[[[0.12, 0.12, 0.12], + [0.27999997, 0.27999997, 0.27999997]], + [[0.18000001, 0.18000001, 0.18000001], + [0.42000002, 0.42000002, 0.42000002]]]] + + expected_grad_warp = [[199, 30]] + + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + + def testBatch2Height3byWidth3RGB(self): + for dtype in self.float_types: + input_shape = [2, 3, 3, 3] + input_rgb_data = [ + 0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1, 30, 105, 2, 40, 115, + 3, 50, 125, 4, 60, 135, 5, 70, 145, 6, 0, 5, 13, 54, 135, 226, 37, 8, + 234, 90, 255, 1, 30, 105, 2, 40, 115, 3, 50, 125, 4, 60, 135, 5, 70, + 145, 6 + ] + input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + + # 2 batches and 2 samples for each batch. + warp_shape = [2, 2, 2] + warp_data = [0.7, 0.6, 1, 0.7, 0.9, 1.2, 1.3, 1.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + + expected_forward = [[[43.92, 128.4, 65.86], [37.2, 114., 69.2]], + [[40.6, 122.8, 2.5], [51., 126, 4.1]]] + + self._assertForwardOpMatchesExpected(input_np, warp_np, expected_forward) + + expected_grad_data = [[[[0.12, 0.12, 0.12], + [0.57999998, 0.57999998, 0.57999998], + [0., 0., 0.]], + [[0.18000001, 0.18000001, 0.18000001], + [1.12, 1.12, 1.12], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0.08000001, 0.08000001, 0.08000001], + [0.99999988, 0.99999988, 0.99999988], + [0.11999997, 0.11999997, 0.11999997]], + [[0.02000001, 0.02000001, 0.02000001], + [0.60000008, 0.60000008, 0.60000008], + [0.17999998, 0.17999998, 0.17999998]]]] + expected_grad_warp = [[[33.39999008, -96.20000458], [-26.10000229, + -278.]], + [[-162.99998474, 39.99999619], [21., 63.]]] + + grad_output = np.ones([2, 2, 3], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 7d0eb7ef822464..d612d3b32dd6b0 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -358,6 +358,11 @@ def testFloatOps(self): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.leaky_relu, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([1, 2, 3, 4], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5fc9a352ff930c..f18d8c20089625 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -194,6 +194,7 @@ cc_library( ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 9ee4178f5c213e..d85b4f5ae0cb9c 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -178,6 +178,32 @@ tf_kernel_library( ], ) +# A separate cc_library for resampler_ops is needed because resampler is in +# contrib/, and thus the declaration of resampler cannot be pulled into the deps +# of xla_ops. Therefore, resampler_ops is its own cc_library target, and its +# corresponding tf_kernel_library is defined in contrib/resampler/BUILD. +cc_library( + name = "resampler_ops", + srcs = ["resampler_ops.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + cc_library( name = "conv_op_helpers", srcs = ["conv_op_helpers.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index fa04b0f7d00299..0c7ca602bfacd5 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -231,20 +231,22 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, num_extended[0] = upper_padding[0] / (dims.kernel_size[0]); num_extended[1] = upper_padding[1] / (dims.kernel_size[1]); + const int64 batch_dim_size = + builder->GetShape(input).ValueOrDie().dimensions(0); if (num_extended[0] > 0) { - auto slice = - xla::Slice(input_data, {0, in_size[0] - 1, 0, 0}, - {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + auto slice = xla::Slice( + input_data, {0, in_size[0] - 1, 0, 0}, + {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); for (int i = 0; i < num_extended[0]; i++) { input_data = xla::ConcatInDim(builder, {input_data, slice}, 1); } } if (num_extended[1] > 0) { - auto slice = - xla::Slice(input_data, {0, 0, in_size[1] - 1, 0}, - {1, in_size[0] + num_extended[0], in_size[1], channels}, - {1, 1, 1, 1}); + auto slice = xla::Slice( + input_data, {0, 0, in_size[1] - 1, 0}, + {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels}, + {1, 1, 1, 1}); for (int i = 0; i < num_extended[1]; i++) { input_data = xla::ConcatInDim(builder, {input_data, slice}, 2); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index d35777ccb1271e..a8e230ba107ce8 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -15,14 +15,12 @@ limitations under the License. // Native XLA implementations of XLA Relu Ops -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { @@ -37,6 +35,7 @@ class ReluOp : public XlaOpKernel { ctx->SetOutput(0, xla::Max(zero, ctx->Input(0))); } }; +REGISTER_XLA_OP(Name("Relu"), ReluOp); class Relu6Op : public XlaOpKernel { public: @@ -49,6 +48,22 @@ class Relu6Op : public XlaOpKernel { ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six)); } }; +REGISTER_XLA_OP(Name("Relu6"), Relu6Op); + +class LeakyReluOp : public XlaOpKernel { + public: + explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto features = ctx->Input("features"); + auto output = + xla::Max(features, features * xla::ScalarLike(features, alpha_)); + ctx->SetOutput(0, output); + } + float alpha_; +}; +REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp); class ReluGradOp : public XlaOpKernel { public: @@ -64,6 +79,7 @@ class ReluGradOp : public XlaOpKernel { ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero)); } }; +REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); class Relu6GradOp : public XlaOpKernel { public: @@ -83,11 +99,24 @@ class Relu6GradOp : public XlaOpKernel { ctx->SetOutput(0, out); } }; - -REGISTER_XLA_OP(Name("Relu"), ReluOp); -REGISTER_XLA_OP(Name("Relu6"), Relu6Op); -REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp); +class LeakyReluGradOp : public XlaOpKernel { + public: + explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto gradients = ctx->Input("gradients"); + auto features = ctx->Input("features"); + auto output = + xla::Select(xla::Gt(features, xla::ScalarLike(features, 0)), gradients, + gradients * xla::ScalarLike(gradients, alpha_)); + ctx->SetOutput(0, output); + } + float alpha_; +}; +REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc new file mode 100644 index 00000000000000..847704608fb32b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -0,0 +1,541 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +using xla::XlaOp; + +// TODO(b/112295522): note that sampling from image boundary is not currently +// being handled properly. + +// Calculates the bilinear weight tensor, given basis ratio (px, py) of the +// sampling position: +// W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] +// 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2]. +// +// The returned tensor has dimensions [batch, dim_0, ... dim_n, 4]. +XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, + const TensorShape warp_shape, + xla::PrimitiveType xla_type) { + auto first_term = xla::ConstantR2( + ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}}); + first_term = xla::ConvertElementType(first_term, xla_type); + + auto warp_dims = warp_shape.dim_sizes(); + std::vector broadcast_dims(warp_dims.begin(), warp_dims.end() - 1); + broadcast_dims.push_back(4); + broadcast_dims.push_back(2); + + const int64 broadcast_dims_size = broadcast_dims.size(); + + std::vector last_two_dims_indices = {(broadcast_dims_size - 2), + (broadcast_dims_size - 1)}; + + xla::Shape broadcast_shape = + xla::ShapeUtil::MakeShape(xla_type, broadcast_dims); + + auto broadcast_first_term = + xla::BroadcastInDim(first_term, broadcast_shape, last_two_dims_indices); + + // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n, + // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the + // [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last + // dimension. + std::vector ratio_broadcast_indices(broadcast_dims.size()); + std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0); + ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2); + + auto broadcast_ratio = + xla::BroadcastInDim(ratio, broadcast_shape, ratio_broadcast_indices); + + auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio; + + // Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to + // flip the signs of the second and the third term. + auto sign_change = xla::ConstantR2( + ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}}); + sign_change = xla::ConvertElementType(sign_change, xla_type); + + auto broadcast_sign_change = + xla::BroadcastInDim(sign_change, broadcast_shape, last_two_dims_indices); + + auto flipped = first_term_subtract_weights * broadcast_sign_change; + + // Build up the final bilinear weight tensor by multiply reduction, which + // gives: + // [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] + // for each 4 neighboring pixels where px and py are the weight of the target + // pixel we are sampling from. + return xla::Reduce( + flipped, xla::One(ctx->builder(), xla_type), + xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()), + {broadcast_dims_size - 1}); +} + +// Concatenates the batch indices to the (x, y) coordinate indices. +// This is done by first creating an Iota tensor that represents the current +// batch it is in, then concatenate with the givin (coordinate) indices. +// +// The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where +// the last dimension of size 3 in turn is [batch_number, x, y]. +// The [batch_number, x, y] dimension is needed because the indices +// [x,y] alone cannot allow the xla::Gather operation to gather from the input +// data, which is of dimension [batch, height(y), width(x), channel] with +// 'batch' being the first dimension. +XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, + const TensorShape& warp_shape) { + // We need to create an iota tensor with the same batch dimension. + std::vector dimensions; + for (auto dim : warp_shape) { + dimensions.push_back(dim.size); + } + // Except the last dimension, which is of size 1. + dimensions.back() = 1; + + auto batch_indices = + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + /*iota_dimension=*/0); + + return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); +} + +// Gathers the 2x2 neighbors of the input starting_indices, and return a +// tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels]. +// 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last +// dimension of size 3 is (batch_no, x, y). +XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices, + int64 data_channels, int warp_dims) { + xla::GatherDimensionNumbers gather_dim_numbers; + const int64 neighbor_data_dimensions = warp_dims + 2; + // Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2, + // data_channels], the offset dimensions for Gather is the last 3 dimensions. + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3); + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2); + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1); + // The last dimension of 'gather_indices' is the starting indices for gather. + gather_dim_numbers.set_index_vector_dim(warp_dims - 1); + gather_dim_numbers.add_collapsed_slice_dims(0); + gather_dim_numbers.add_start_index_map(0); + // Since input is of dimension [batch, height(y), width(x), channel], and warp + // is of dimension [batch, x, y], the ordering of x, y here needs to be + // swapped when gathering. + gather_dim_numbers.add_start_index_map(2); + gather_dim_numbers.add_start_index_map(1); + // Data dimensions are [batch, x, y, channel]. + // Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels]. + auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers, + /*slice_sizes=*/{1, 2, 2, data_channels}); + // Collapse the ...,2,2,... dimensions into ...,4,... + return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims}); +} + +// Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the +// resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels]. +// This function can also be seen as the inverse of 'Gather2by2Neighbors'. +XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, + XlaOp updates, int64 warp_dims, + xla::PrimitiveType xla_type) { + xla::ScatterDimensionNumbers scatter_dim_numbers; + const int64 neighbor_data_dimensions = warp_dims + 2; + // Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2, + // data_channels], the update window dimensions is the last 3 dimensions. + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3); + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2); + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1); + scatter_dim_numbers.set_index_vector_dim(warp_dims - 1); + + scatter_dim_numbers.add_inserted_window_dims(0); + scatter_dim_numbers.add_scatter_dims_to_operand_dims(0); + // Since input is of dimension [batch, height(y), width(x), channel], and warp + // is of dimension [batch, x, y], the ordering of x, y here needs to be + // swapped when scattering. + scatter_dim_numbers.add_scatter_dims_to_operand_dims(2); + scatter_dim_numbers.add_scatter_dims_to_operand_dims(1); + + return xla::Scatter(grad_data, indices, updates, + xla::CreateScalarAddComputation(xla_type, ctx->builder()), + scatter_dim_numbers); +} + +// Build computation the backprop into input 'data'. +// Where input: +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// +// Output: +// scatter-add to each 2x2 grad_data neighbor: +// grad_data[fx, fy, chan] += output_grad * dx * dy +// grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy +// grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) +// grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) +// where (dx, dy) is (1 - ratio). +XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, + XlaOp gather_indices, xla::PrimitiveType warp_type, + TensorShape warp_shape, int64 data_channels, + xla::Shape data_shape) { + // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. + auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); + + auto warp_dims = warp_shape.dim_sizes(); + std::vector warp_dims_without_last_dims(warp_dims.begin(), + warp_dims.end() - 1); + + std::vector reshaped_weights_dims = warp_dims_without_last_dims; + // Reshape the last dimension of size 4 to two dimensions [2, 2]. + reshaped_weights_dims.push_back(2); + reshaped_weights_dims.push_back(2); + std::vector reshape_dims(warp_shape.dims()); + std::iota(reshape_dims.begin(), reshape_dims.end(), 0); + // The dimension is [batch, dim_0,..., dim_n, 2, 2]. + auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims, + /*new_sizes=*/reshaped_weights_dims); + + std::vector weights_with_channels_dims = reshaped_weights_dims; + weights_with_channels_dims.push_back(data_channels); + auto weights_with_channels_shape = + xla::ShapeUtil::MakeShape(warp_type, weights_with_channels_dims); + std::vector reshaped_weights_indices(reshaped_weights_dims.size()); + std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), + 0); + + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. + auto broadcast_reshaped_weights = xla::BroadcastInDim( + reshaped_weights, weights_with_channels_shape, reshaped_weights_indices); + + std::vector grad_output_indices(warp_dims_without_last_dims.size()); + std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0); + grad_output_indices.push_back(weights_with_channels_dims.size() - 1); + XlaOp broadcast_grad_output = xla::BroadcastInDim( + grad_output, weights_with_channels_shape, grad_output_indices); + + auto grad_output_multiply_weights = + broadcast_grad_output * broadcast_reshaped_weights; + + auto grad_data = xla::ConstantLiteral( + ctx->builder(), xla::Literal::CreateFromShape(data_shape)); + + return ScatterToGradData(ctx, grad_data, gather_indices, + grad_output_multiply_weights, warp_shape.dims(), + warp_type); +} + +// Build computation for the backprop into input 'warp'. +// Where input: +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data is of dimension [batch, x, y, channel] +// +// Output (simplified by ignoring the batch dimensions): +// Since the forward path has: +// output = dot(weights * neighbors) +// The backprop into warp will therefore be: +// grad_warp = output_grad * d_output / d_warp +// = output_grad * (d_weights / d_warp * neighbors + d_neighbors / +// d_warp * weight) +// Where: +// d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py] +// d_weights / d_warp_y = [-(1 - px), -px, (1-px), px] +// and +// d_neighbors / d_warp_x = 0 +// +// Therefore: +// grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) +// grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) +// +// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// bottom right corner in a 2x2 neighborhood. +XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, + XlaOp gather_indices, XlaOp data, + TensorShape warp_shape, int64 data_channels, + xla::PrimitiveType data_type) { + auto warp_dims = warp_shape.dim_sizes(); + std::vector warp_dims_without_last_dims(warp_dims.begin(), + warp_dims.end() - 1); + + std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; + neighbor_broadcast_dims.push_back(4); + + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = Gather2by2Neighbors( + ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + + const int64 last_warp_dim = warp_shape.dims() - 1; + + // Since we will be creating the dot product of: + // lhs: [batch, dim_0, ...dim_n, 4] + // and + // rhs: [batch, dim_0, ...dim_n, 4, data_channels] + // we choose the last dimension of lhs and the second last dimension of rhs, + // with size 4, as the contracting dimension. + xla::DotDimensionNumbers dot_dims; + for (int i = 0; i < warp_shape.dims() - 1; ++i) { + dot_dims.add_lhs_batch_dimensions(i); + dot_dims.add_rhs_batch_dimensions(i); + } + dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); + dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + + // img_cxcy - img_fxcy + auto bottom_right_minus_bottom_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {0, 0, -1, 1}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_cxfy - img_fxfy + auto top_right_minus_top_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, 1, 0, 0}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_cxcy - img_cxfy + auto bottom_right_minus_top_right = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {0, -1, 0, 1}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_fxcy - img_fxfy + auto bottom_left_minus_top_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, 0, 1, 0}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // Slice out x and y. + auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1, + /*stride=*/1, /*dimno=*/last_warp_dim); + auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2, + /*stride=*/1, /*dimno=*/last_warp_dim); + + // Build 1 - y and 1 - x. + auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y; + auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x; + + auto x_before_reduce = + grad_output * weight_y * bottom_right_minus_bottom_left + + one_minus_y * top_right_minus_top_left; + + std::vector reshaped_sizes = warp_dims_without_last_dims; + reshaped_sizes.push_back(1); + + std::vector reshaped_dims(warp_dims_without_last_dims.size()); + std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0); + + // Reduce-add along the channel dimension. + auto x_result = + xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type), + xla::CreateScalarAddComputation(data_type, ctx->builder()), + {last_warp_dim}); + // Reshape before concatenating with y values. + XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes); + + auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right + + one_minus_x * bottom_left_minus_top_left; + // Reduce-add along the channel dimension. + auto y_result = + xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type), + + xla::CreateScalarAddComputation(data_type, ctx->builder()), + {last_warp_dim}); + XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes); + + return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y}, + last_warp_dim); +} + +class ResamplerOp : public XlaOpKernel { + public: + explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape data_shape = ctx->InputShape("data"); + OP_REQUIRES(ctx, data_shape.dims() == 4, + errors::InvalidArgument("data must be 4-dimensional", + data_shape.DebugString())); + const int64 data_channels = data_shape.dim_size(3); + xla::PrimitiveType data_type = ctx->input_xla_type(0); + + TensorShape warp_shape = ctx->InputShape("warp"); + OP_REQUIRES(ctx, warp_shape.dims() >= 2, + errors::InvalidArgument("warp must be at least 2-dimensional", + warp_shape.DebugString())); + for (int size : warp_shape.dim_sizes()) { + OP_REQUIRES(ctx, size > 0, + errors::InvalidArgument("warp sizes must be positive, got [", + size, "]")); + } + const int64 last_warp_dim = warp_shape.dims() - 1; + // Last dimension of warp shape must be of size 2. + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, + errors::InvalidArgument( + "the last dimension of warp must be exactly size 2.")); + + XlaOp data = ctx->Input("data"); + XlaOp warp = ctx->Input("warp"); + + // Find the coordinates of the top left corner for the 2x2 region to be + // sampled from. The dimensions are (batch, dim_0, ... dim_n, 2) where the + // last dimension of size 2 in turn is [x, y]. + XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + + auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = Gather2by2Neighbors( + ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + + // Dimensions are [batch, dim_0, ... dim_n, 2]. + XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type); + + // Obtain the bilinear blending weights, the dimension is [batch, dim_0, + // ...dim_n, 4]. + auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type); + + // Since we will be creating the dot product of: + // lhs: [batch, dim_0, ...dim_n, 4] + // and + // rhs: [batch, dim_0, ...dim_n, 4, data_channels] + // we choose the last dimension of lhs and the second last dimension of rhs, + // with size 4, as the contracting dimension. + xla::DotDimensionNumbers dot_dims; + for (int i = 0; i < warp_shape.dims() - 1; ++i) { + dot_dims.add_lhs_batch_dimensions(i); + dot_dims.add_rhs_batch_dimensions(i); + } + dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); + dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + + auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims, + /*precision_config=*/nullptr); + + ctx->SetOutput(0, blended_pixels); + } +}; + +REGISTER_XLA_OP(Name("Resampler"), ResamplerOp); + +class ResamplerGradOp : public XlaOpKernel { + public: + explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape data_shape_tf = ctx->InputShape("data"); + OP_REQUIRES(ctx, data_shape_tf.dims() == 4, + errors::InvalidArgument("data must be 4-dimensional", + data_shape_tf.DebugString())); + const int64 data_channels = data_shape_tf.dim_size(3); + xla::PrimitiveType data_type = ctx->input_xla_type(0); + + TensorShape warp_shape = ctx->InputShape("warp"); + OP_REQUIRES(ctx, warp_shape.dims() >= 2, + errors::InvalidArgument("warp must be at least 2-dimensional", + warp_shape.DebugString())); + for (int size : warp_shape.dim_sizes()) { + OP_REQUIRES(ctx, size > 0, + errors::InvalidArgument("warp sizes must be positive, got [", + size, "]")); + } + // Last dimension of warp shape must be of size 2. + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + errors::InvalidArgument( + "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); + + TensorShape output_grad_shape = ctx->InputShape("grad_output"); + OP_REQUIRES( + ctx, output_grad_shape.dims() >= 2, + errors::InvalidArgument("output_grad must be at least 2-dimensional", + output_grad_shape.DebugString())); + + // Dimensions are [batch, x, y, channel]. + XlaOp data = ctx->Input("data"); + xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf); + + // Dimensions are [batch, dim_0, ...dim_n, 2]. + XlaOp warp = ctx->Input("warp"); + // Dimensions are [batch, dim_0, ...dim_n, channel]. + XlaOp grad_output = ctx->Input("grad_output"); + + // Find the top left corner coordinate for the region to be sampled from. + // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension + // of size 2 in turn is [x, y]. + XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + + // Dimensions are [batch, dim_0, ... dim_n, 2] + XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); + + // Indices for gathering neighboring pixels. + auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); + + auto grad_data = + CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, + warp_shape, data_channels, data_shape); + + auto grad_warp = + CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, + warp_shape, data_channels, data_type); + + ctx->SetOutput(0, grad_data); + ctx->SetOutput(1, grad_warp); + } +}; + +REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 03a50ef8a059e5..7ff3e916381143 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,8 +17,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -61,113 +63,79 @@ class ReverseSequenceOp : public XlaOpKernel { const auto seq_lens = context->Input(1); const int64 batch_size = input_shape.dim_size(batch_dim_); + if (batch_size == 0) { + context->SetOutput(0, input); + return; + } - const DataType input_type = context->input_type(0); - const DataType seq_lens_type = context->input_type(1); + // Given the input + // + // 012345 + // 6789AB + // + // and sequence lens {2, 3} we: + // + // 1. Reverse and pad each row to get + // + // 543210XXXXXX + // BA9876XXXXXX + // + // 2. Gather out the suffix from each row to get + // + // 10XXXX + // 876XXX + // + // 3. Select from the input and the array created by (2) to get the result. + // + // 102345 + // 8769AB + const xla::PrimitiveType input_type = context->input_xla_type(0); + const xla::PrimitiveType seq_lens_type = context->input_xla_type(1); const int64 max_seq_len = input_shape.dim_size(seq_dim_); - xla::Shape input_xla_shape; - OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, - &input_xla_shape)); - xla::Shape seq_lens_xla_shape; - OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, - &seq_lens_xla_shape)); - - const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ - xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), - seq_lens_xla_shape, - input_xla_shape, - }); - - // For each entry in the batch, reverse the sequence. - // TODO(b/65689298): generalize the Map() operator to non-scalar cases and - // use it here, instead of a While loop. - - // Condition: lambda (i, _, _): i < batch_size - auto condition_builder = - builder->CreateSubBuilder("reverse_sequence_condition"); - { - auto param = - xla::Parameter(condition_builder.get(), 0, tuple_shape, "param"); - auto i = xla::GetTupleElement(param, 0); - xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(), - seq_lens_type, batch_size)); - } - auto condition = condition_builder->Build(); - OP_REQUIRES_OK(context, condition.status()); - - auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); - { - auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param"); - auto i = xla::GetTupleElement(param, 0); - auto seq_lens = xla::GetTupleElement(param, 1); - auto output = xla::GetTupleElement(param, 2); - - // seq_len is the sequence length of the current batch element (rank 1) - auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1}); - - // Indices is the offset of the batch element in the input. - auto batch_element_indices = - xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {input_shape.dims()}); - batch_element_indices = xla::DynamicUpdateSlice( - batch_element_indices, xla::Reshape(i, {1}), - xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), - seq_lens_type, batch_dim_), - {1})); - - // Slice out the current batch element and pad it out in the sequence - // dimension. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - slice_shape.set_dim(seq_dim_, max_seq_len); - auto slice = xla::DynamicSlice(output, batch_element_indices, - slice_shape.dim_sizes()); - auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); - padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( - slice_shape.dim_size(seq_dim_)); - slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type), - padding_config); - - // Now slice out the reversed sequence from its actual start. - // sequence_start_indices is the offset of the start of the reversed - // sequence in the input. The slice will go into the padding, however, we - // will mask off these elements and replace them with elements from the - // original input so their values do not matter. - auto sequence_start_indices = - xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {slice_shape.dims()}); - sequence_start_indices = xla::DynamicUpdateSlice( - sequence_start_indices, - xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - max_seq_len), - seq_len), - xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), - seq_lens_type, seq_dim_), - {1})); - slice = xla::DynamicSlice(slice, sequence_start_indices, - slice_shape.dim_sizes()); - - // Shift the reversed sequence to the left. - output = xla::DynamicUpdateSlice(output, slice, batch_element_indices); - - xla::Tuple( - body_builder.get(), - {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)), - seq_lens, output}); + xla::XlaOp rev = xla::Rev(input, {seq_dim_}); + + auto padding_config = xla::MakeNoPaddingConfig(input_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + max_seq_len); + xla::XlaOp padded = + xla::Pad(rev, xla::Zero(builder, input_type), padding_config); + + // Form a start indices tensor with shape [2, batch_size]. For each batch + // entry we have a (batch offset, seq offset) pair. + xla::XlaOp start_indices = xla::ConcatInDim( + builder, + { + xla::Iota(builder, + xla::ShapeUtil::MakeShape(seq_lens_type, {1, batch_size}), + /*iota_dimension=*/1), + xla::Reshape(xla::ScalarLike(seq_lens, max_seq_len) - seq_lens, + {1, batch_size}), + }, + /*dimension=*/0); + + xla::GatherDimensionNumbers dnums; + // The first dimension of start_indices contains the batch/seq dim choice. + dnums.set_index_vector_dim(0); + dnums.add_start_index_map(batch_dim_); + dnums.add_start_index_map(seq_dim_); + + // All other dimensions other than the batch dim are offset dimensions. + for (int i = 0; i < input_shape.dims(); ++i) { + if (i != batch_dim_) { + dnums.add_offset_dims(i); + } } - auto body = body_builder->Build(); - OP_REQUIRES_OK(context, body.status()); - - auto loop_output = xla::While( - condition.ValueOrDie(), body.ValueOrDie(), - xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens, - xla::Rev(input, {seq_dim_})})); - auto output = xla::GetTupleElement(loop_output, 2); - - // Mask out elements after the sequence length. - xla::XlaOp iota = - xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); + dnums.add_collapsed_slice_dims(batch_dim_); + + auto slice_sizes = input_shape.dim_sizes(); + slice_sizes[batch_dim_] = 1; + + xla::XlaOp output = xla::Gather(padded, start_indices, dnums, slice_sizes); + + // Mask out elements after the sequence length, and copy the corresponding + // elements from the input. + xla::XlaOp iota = xla::Iota(builder, seq_lens_type, max_seq_len); std::vector dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d79cdad9fa2dab..7b96b43ad834c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp); +REGISTER_XLA_OP( + Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(), + StackOp); class StackPushOp : public XlaOpKernel { public: @@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); }; -REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); +REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp); class StackPopOp : public XlaOpKernel { public: @@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); }; -REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); +REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp); class StackCloseOp : public XlaOpKernel { public: @@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); }; -REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); +REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 0394b6b533ff97..cc81772e8c5da7 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -75,6 +76,222 @@ Status CheckFeedFetchNameConflicts(const string& kind, return Status::OK(); } +// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to +// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`. +Status CopyAssociatedFunctions(Graph* g, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + for (Node* n : g->op_nodes()) { + for (const auto& associated_function : + GetAssociatedFunctions(*n, lookup_fld)) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + const FunctionDef* fdef = + lookup_fld->Find(associated_function.func_name()); + if (!fdef) { + return errors::Internal( + "Cannot find function ", associated_function.func_name(), + " for function call node ", n->DebugString()); + } + TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef)); + break; + } + case AssociatedFunctionInfo::kSymbolicGradient: + case AssociatedFunctionInfo::kFunctionAttr: + break; + } + } + } + return Status::OK(); +} + +// For graph `g`, replaces _Arg nodes whose "index" attribute is in +// `const_input_index_to_node` with Const nodes. +Status ReplaceArgUsageWithConstNode( + Graph* g, + const std::unordered_map& const_input_index_to_node) { + // Collect all _Arg nodes. + std::unordered_map arg_nodes; + for (Node* n : g->op_nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + arg_nodes[index] = n; + } + } + + for (const auto& iter : const_input_index_to_node) { + int arg_index = iter.first; + Node* const_node = g->CopyNode(iter.second); + Node* arg_node = arg_nodes[arg_index]; + + // Collect all usages of the _Arg node. + struct OutEdgeInfo { + int dst_node_id, dst_input; + }; + std::vector usages; + for (const Edge* e : arg_node->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + usages.push_back({e->dst()->id(), e->dst_input()}); + } + + for (int i = 0; i < usages.size(); i++) { + // Make a copy of `usage_node`, and change its input to const node. + Node* usage_node = g->FindNodeId(usages[i].dst_node_id); + NodeDef replace_def = usage_node->def(); + *replace_def.mutable_input(usages[i].dst_input) = const_node->name(); + TF_ASSIGN_OR_RETURN(Node * replace_node, + ReplaceNode(g, usage_node, replace_def)); + const Edge* usage_edge; + TF_RETURN_IF_ERROR( + replace_node->input_edge(usages[i].dst_input, &usage_edge)); + g->RemoveEdge(usage_edge); + g->AddEdge(const_node, 0, replace_node, usages[i].dst_input); + + // Later entries in `usages` might have `usage_node` as dst node, but + // `usage_node` is removed. Replace such entries with `replace_node`. + for (int j = i + 1; j < usages.size(); j++) { + if (usages[j].dst_node_id == usages[i].dst_node_id) { + usages[j].dst_node_id = replace_node->id(); + } + } + } + } + return Status::OK(); +} + +// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites +// the function to replace _Arg nodes in `const_input_index_to_node` with Const +// inputs. +Status PropagateConstIntoFuncAttr( + Node* n, const string& attr_name, + const std::unordered_map& const_input_index_to_node, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + // Instantiate the function. + NameAttrList func_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr)); + const FunctionDef* fdef = lookup_fld->Find(func_attr.name()); + if (!fdef) { + return errors::Internal("Cannot find function ", func_attr.name(), + " for node ", n->name()); + } + FunctionBody* fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fdef, AttrSlice(&func_attr.attr()), lookup_fld, + [lookup_fld](const string& op, const OpDef** sig) { + return lookup_fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + + // Rewrite _Arg usages with Const node. + Graph* func_graph = fbody->graph; + TF_RETURN_IF_ERROR( + ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node)); + + // Save rewritten function. + FunctionDef replace_fdef; + string new_func_name = + fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_")); + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef)); + + // Change the node to use rewritten function. + func_attr.set_name(new_func_name); + n->ClearAttr(attr_name); + n->AddAttr(attr_name, func_attr); + + // Copy associated functions. + TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld)); + + return Status::OK(); +} + +// For an "If" node in graph `g`, if it has Const node inputs, rewrite its +// then/else branch function to replace _Arg nodes with those Const inputs. +Status PropagateConstIntoIfNode(Graph* g, Node* if_node, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + // Notice that first input for If node is predicate; other inputs are function + // inputs. + std::unordered_map const_input_index_to_node; + for (int i = 1; i < if_node->num_inputs(); i++) { + const Node* input_node; + TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node)); + if (input_node->type_string() == "Const") { + const_input_index_to_node[i - 1] = input_node; + } + } + if (const_input_index_to_node.empty()) { + return Status::OK(); + } + + // Rewrite "then_branch" and "else_branch" function, replace usage of those + // _Arg nodes with corresponding const node. + for (const auto& attr_name : + std::vector{"then_branch", "else_branch"}) { + TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr( + if_node, attr_name, const_input_index_to_node, lookup_fld, fld)); + } + + return Status::OK(); +} + +// For a "While" node in graph `g`, if it has Const node inputs, rewrite its +// cond/body function to replace _Arg nodes with those Const inputs. +Status PropagateConstIntoWhileNode(Graph* g, Node* while_node, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + // For "While" node, we should only replace _Arg nodes which are loop + // invariants. For such _Arg nodes, the return value's input will come + // directly from the corresponding arg. + std::unordered_map const_input_index_to_node; + NameAttrList body_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr)); + const FunctionDef* body_func = lookup_fld->Find(body_attr.name()); + if (!body_func) { + return errors::Internal("Cannot find body function ", body_attr.name(), + " for While node ", while_node->name()); + } + for (int i = 0; i < while_node->num_inputs(); i++) { + const Node* input_node; + TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node)); + if (input_node->type_string() != "Const") { + continue; + } + + // Check if i-th retval's input comes from i-th arg directly. + const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i); + auto output_arg_input = body_func->ret().find(output_arg.name()); + if (output_arg_input == body_func->ret().end()) { + return errors::Internal("Cannot find input for output arg ", + output_arg.name(), " in function ", + body_attr.name()); + } + const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i); + if (output_arg_input->second != input_arg.name()) { + continue; + } + + const_input_index_to_node[i] = input_node; + } + if (const_input_index_to_node.empty()) { + return Status::OK(); + } + + // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with + // corresponding const node. + for (const auto& attr_name : std::vector{"cond", "body"}) { + TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr( + while_node, attr_name, const_input_index_to_node, lookup_fld, fld)); + } + return Status::OK(); +} + } // namespace const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; @@ -520,4 +737,17 @@ xla::StatusOr BuildIdentityNode( return id_node; } +Status PropagateConstIntoFunctionalNodes( + Graph* g, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + for (Node* n : g->op_nodes()) { + if (n->type_string() == "If") { + TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld)); + } else if (n->type_string() == "While") { + TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld)); + } + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 1232ed8c676ff3..cf3aa2f847c5ad 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -183,6 +183,20 @@ xla::StatusOr BuildIdentityNode(Graph* graph, const string& node_name, DataType dtype, const Node* input, absl::optional requested_device); +// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite +// body functions to use the Const nodes instead of original _Arg nodes. +// +// For example, say we have the following computation: +// shape = constant_op.constant([1]) +// return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape)) +// If we do not rewrite then/else function, they will use _Arg node as shape +// input for tf.ones/tf.zeros. But XLA requires that shape input to be compile +// time constant, so XLA compilation will fail. This rewriting process will +// change the shape input to Const node. +Status PropagateConstIntoFunctionalNodes( + Graph* g, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index fe484c1cbe97d0..e177a5f07f5607 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -756,6 +756,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; + TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( + graph.get(), options_.flib_def, local_flib_def_.get())); if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 9f00de708cc5ac..dcd0e9c5c1f20c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. static void* registration_init = [®istry]() { + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + mutex_lock lock(registry.mutex_); if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + cpu_global_jit + ? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally + : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; registration.compile_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = true; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; registration.compile_resource_ops = false; } return nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 45a40c0acc0780..0bdd4a10854454 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -66,19 +66,26 @@ class XlaOpRegistry { public: typedef OpKernel* (*Factory)(OpKernelConstruction*); + enum class AutoclusteringPolicy { + // Enable autoclustering if the user requests it, e.g., via + // experimental_jit_scope. Does not autocluster if the JIT is enabled + // globally (e.g., via the OptimizerOptions in the TF session + // configuration.) + kIfExplicitlyRequested, + // Enable autoclustering if explicitly requested, or if the JIT is enabled + // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N. + kIfEnabledGlobally, + // Always try to autocluster ops placed on this device. + kAlways, + }; + // Describes how to compile operators assigned to a device. struct DeviceRegistration { // The name of the an XLA compilation device to use to compile code. string compilation_device_name; - // Do operators assigned to this device require compilation? - bool requires_compilation; - - // If !requires_compilation, should we try to JIT operators on this device - // when XLA JIT compilation is enabled globally via the SessionOptions? - // (It is still possible to explicitly mark operators to JIT compile, even - // if enable_jit_by_default is false.) - bool enable_jit_by_default; + // When should we autocluster operators assigned to this device? + AutoclusteringPolicy autoclustering_policy; // Enable compilation of operators that use DT_RESOURCE types? bool compile_resource_ops = false; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index fc74ef4aa34f60..d6b60c5f991652 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -7,6 +7,7 @@ package_group( packages = [ "//tensorflow/compiler/...", "//tensorflow/contrib/tpu/...", + "//third_party/py/jax/...", ], ) diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index 39f8caaa961dc7..f9c93707f7af30 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -1,7 +1,6 @@

- +

XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear -algebra that optimizes TensorFlow computations. See the -[documentation](https://www.tensorflow.org/performance/xla/) for more details. +algebra that optimizes TensorFlow computations. See the [documentation](./g3doc/overview.md). diff --git a/tensorflow/compiler/xla/g3doc/README.md b/tensorflow/compiler/xla/g3doc/README.md index ab16f04a7e68b9..6643bf0aab3078 100644 --- a/tensorflow/compiler/xla/g3doc/README.md +++ b/tensorflow/compiler/xla/g3doc/README.md @@ -1,3 +1,3 @@ # XLA: Accelerated Linear Algebra -These are the docs for: https://www.tensorflow.org/extend/xla +These are the docs for: https://www.tensorflow.org/xla diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml new file mode 100644 index 00000000000000..bcfbcc3a22f50c --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -0,0 +1,29 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +- include: /api_docs/_upper_tabs_api.yaml +# Dropdown menu +- name: Ecosystem + path: /ecosystem + is_default: true + menu: + - include: /ecosystem/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Guide + contents: + - title: XLA overview + path: /xla/overview + - title: Broadcasting semantics + path: /xla/broadcasting + - title: Developing a new backend for XLA + path: /xla/developing_new_backend + - title: Using JIT compilation + path: /xla/jit + - title: Operation semantics + path: /xla/operation_semantics + - title: Shapes and layout + path: /xla/shapes + - title: Using AOT compilation + path: /xla/tfcompile diff --git a/tensorflow/compiler/xla/g3doc/_index.yaml b/tensorflow/compiler/xla/g3doc/_index.yaml new file mode 100644 index 00000000000000..7934cd11ba22d3 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/_index.yaml @@ -0,0 +1,35 @@ +book_path: /xla/_book.yaml +project_path: /xla/_project.yaml +description: +landing_page: + custom_css_path: /site-assets/css/style.css + rows: + - heading: XLA is a compiler that optimizes TensorFlow computations. + items: + - classname: devsite-landing-row-50 + description: > + XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear + algebra that optimizes TensorFlow computations. The results are + improvements in speed, memory usage, and portability on server and mobile + platforms. The XLA framework is experimental and in active development. + For details, read the XLA guide. + + - classname: devsite-landing-row-cards + items: + - heading: XLA - TensorFlow, compiled + image_path: /ecosystem/images/tf-logo-card-16x9.png + path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html + buttons: + - label: Read on Google Developers blog + path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html + - heading: XLA at the Dev Summit + youtube_id: kAOanJczHA0 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=kAOanJczHA0 + - heading: XLA on GitHub + image_path: /ecosystem/images/github-card-16x9.png + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla + buttons: + - label: View on GitHub + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla diff --git a/tensorflow/compiler/xla/g3doc/_project.yaml b/tensorflow/compiler/xla/g3doc/_project.yaml index 87c4c51e2b6eed..33d8bdb27a664d 100644 --- a/tensorflow/compiler/xla/g3doc/_project.yaml +++ b/tensorflow/compiler/xla/g3doc/_project.yaml @@ -1,6 +1,6 @@ name: XLA breadcrumb_name: XLA -home_url: /extend/xla +home_url: /xla/ parent_project_metadata_path: /_project.yaml description: > XLA is a compiler-based linear algebra execution engine. diff --git a/tensorflow/compiler/xla/g3doc/_toc.yaml b/tensorflow/compiler/xla/g3doc/_toc.yaml deleted file mode 100644 index ef766e8e9bbb31..00000000000000 --- a/tensorflow/compiler/xla/g3doc/_toc.yaml +++ /dev/null @@ -1,16 +0,0 @@ -toc: -- heading: XLA -- title: XLA overview - path: /extend/xla/ -- title: Broadcasting semantics - path: /extend/xla/broadcasting -- title: Developing a new backend for XLA - path: /extend/xla/developing_new_backend -- title: Using JIT compilation - path: /extend/xla/jit -- title: Operation semantics - path: /extend/xla/operation_semantics -- title: Shapes and layout - path: /extend/xla/shapes -- title: Using AOT compilation - path: /extend/xla/tfcompile diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/overview.md similarity index 100% rename from tensorflow/compiler/xla/g3doc/index.md rename to tensorflow/compiler/xla/g3doc/overview.md diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 3fadabcf520709..2a0241af3ef359 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -29,8 +29,6 @@ namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( const Shape& shape, absl::Span multi_index) { DCHECK_EQ(shape.dimensions_size(), multi_index.size()); - // Padding and nested layouts not supported yet. - DCHECK_EQ(0, shape.layout().padded_dimensions_size()); for (size_t i = 0; i < multi_index.size(); ++i) { DCHECK_GE(multi_index[i], 0); @@ -94,8 +92,6 @@ namespace xla { /* static */ std::vector IndexUtil::LinearIndexToMultidimensionalIndex( const Shape& shape, int64 linear_index) { - // Padding and nested layouts not supported yet. - DCHECK_EQ(0, shape.layout().padded_dimensions_size()); DCHECK_GE(linear_index, 0); DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape)); @@ -133,18 +129,12 @@ namespace xla { /* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, int64 dimension) { - int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size(); int64 stride = 1; - DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); for (auto dim : LayoutUtil::MinorToMajor(shape)) { if (dim == dimension) { break; } - if (pdim_size == 0) { - stride *= shape.dimensions(dim); - } else { - stride *= LayoutUtil::PaddedDimension(shape, dim); - } + stride *= shape.dimensions()[dim]; } return stride; } diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 2979cf87dde928..458bdaf2f89819 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -61,8 +61,7 @@ class IndexUtil { static bool BumpIndices(const Shape& shape, absl::Span indices); // Calculates the stride size (in number of elements, not byte size) of a - // given logical shape dimension (from 0 to rank-1). If available, padded - // dimensions are used. + // given logical shape dimension (from 0 to rank-1). // Example: // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 66af644cf78f3c..2398470dd49955 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -201,8 +201,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (!ShapeUtil::IsArray(shape)) { - if (layout.minor_to_major_size() != 0 || - layout.padded_dimensions_size() != 0) { + if (layout.minor_to_major_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", PrimitiveType_Name(shape.element_type())); @@ -241,28 +240,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } dimensions_in_layout[dim] = true; } - - if (layout.padded_dimensions_size() > 0) { - if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { - return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %d", - layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); - } - for (int i = 0; i < layout.padded_dimensions_size(); ++i) { - if (layout.padded_dimensions(i) < shape.dimensions(i)) { - return InvalidArgument( - "for dimension %d, dimension padding (%d) is smaller than " - "the dimension size (%d) of the shape", - i, layout.padded_dimensions(i), shape.dimensions(i)); - } - } - } - } - - if (layout.format() == SPARSE) { - if (!layout.padded_dimensions().empty()) { - return InvalidArgument("Sparse layout has padded dimensions"); - } } return Status::OK(); @@ -303,38 +280,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.minor_to_major().end(), std::greater()); } -/* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || - shape.layout().padded_dimensions_size() == 0) { - return false; - } - CHECK(IsDenseArray(shape)) << shape.ShortDebugString(); - CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { - return true; - } - } - return false; -} - -/* static */ absl::Span LayoutUtil::PaddedDimensions( - const Shape& shape) { - CHECK(IsDenseArray(shape)); - return AsInt64Slice(shape.layout().padded_dimensions()); -} - -/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, - int64 index) { - CHECK(IsDenseArray(shape)); - return shape.layout().padded_dimensions(index); -} - -/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { - CHECK(IsDenseArray(shape)); - return shape.layout().padding_value(); -} - /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && shape.has_layout() && IsSparse(shape.layout()); @@ -513,13 +458,6 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { for (int64 minor_to_major : layout.minor_to_major()) { hash_value = Hash64Combine(hash_value, hash()(minor_to_major)); } - - for (int64 padded_dim : layout.padded_dimensions()) { - hash_value = Hash64Combine(hash_value, hash()(padded_dim)); - } - - hash_value = - Hash64Combine(hash_value, hash()(layout.padding_value())); hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); return hash_value; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 97806d7e331114..6e0390763da151 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -104,23 +104,6 @@ class LayoutUtil { // more minor, and so on until dimension N-1 which is the minor. static bool IsMonotonicWithDim0Major(const Layout& layout); - // Returns whether the layout of the given shape has padding (a - // padded_dimension value in Layout is greater than the corresponding - // dimension size). - static bool IsPadded(const Shape& shape); - - // Returns the padded_dimensions array for the given Shape. Requires that the - // shape is an array and has a dense layout. - static absl::Span PaddedDimensions(const Shape& shape); - - // Returns the given index of the padded_dimensions array for the given Shape. - // Requires that the shape is an array and has a dense layout. - static int64 PaddedDimension(const Shape& shape, int64 index); - - // Returns the padding_value for the given Shape. Requires that the shape is - // an array and has a dense layout. - static PaddingValue GetPaddingValue(const Shape& shape); - // Returns whether the given Shape is an array (i.e. not a tuple) and has a // sparse format layout. static bool IsSparseArray(const Shape& shape); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index a50d53eaeb15da..12ce2d2d7c6fa8 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -304,30 +304,6 @@ TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { shape.tuple_shapes(1).layout())); } -TEST_F(LayoutUtilTest, IsPadded) { - Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); - LayoutUtil::ClearLayout(&shape_without_layout); - EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout)); - - Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); - LayoutUtil::SetToDefaultLayout(&shape_with_layout); - EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout)); - - // Add padding equal to the dimension sizes. In this case the padding is a - // nop. - Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); - shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2); - shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3); - shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4); - EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding)); - - Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); - shape_with_padding.mutable_layout()->add_padded_dimensions(2); - shape_with_padding.mutable_layout()->add_padded_dimensions(14); - shape_with_padding.mutable_layout()->add_padded_dimensions(42); - EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding)); -} - TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), LayoutUtil::GetDefaultLayoutForR2())); diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 510aa39b450311..80dfdb83c35183 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1075,12 +1075,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, auto element_to_string = [&](absl::Span indices) -> string { PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - literal.GetAsString(indices, shape_index); + // We display predicates as 0s and 1s so that the string is more dense. + string elem = element_type == PRED + ? literal.Get(indices, shape_index) ? "1" : "0" + : literal.GetAsString(indices, shape_index); + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem; }; if (ShapeUtil::Rank(subshape) == 0) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 8d4b974c166901..9d34d9d504156c 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -34,16 +34,22 @@ namespace xla { namespace literal_comparison { namespace { +// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be +// able to transparently access the raw 16-bit value contained within. +template +T GetRawValue(T val) { + return val; +} +uint16 GetRawValue(Eigen::half val) { return val.x; } + // Helper function for comparing a floating point type, FloatT, bitwise equal // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, absl::Span multi_index) { - // TODO(b/118627822): These are unsafe bit_casts because Eigen::Half is not - // trivially copyable. - auto ulhs = absl::bit_cast(lhs); - auto urhs = absl::bit_cast(rhs); + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); if (ulhs != urhs) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 4ae5ddbfdb8444..3511760ac1cad1 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{101}", pred_vec.ToString()); + EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index e4e93090c8236b..b1fae826ab1903 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -184,8 +184,9 @@ StatusOr LocalShapedBufferTuple::Release(int i) { int64 LocalShapedBufferTuple::size() const { return elements_.size(); } -XrtAllocation::XrtAllocation(int64 handle, Shape shape) - : handle_(handle), shape_(shape) {} +XrtAllocation::XrtAllocation(int64 handle, Shape shape, + const string& session_target) + : handle_(handle), shape_(shape), session_target_(session_target) {} XrtAllocation::~XrtAllocation() { tensorflow::Scope root = tensorflow::Scope::NewRootScope(); @@ -198,7 +199,7 @@ XrtAllocation::~XrtAllocation() { return; } - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target_); tensorflow::ClientSession::FeedType inputs; inputs.insert({allocation_handle, handle()}); std::vector outputs; @@ -210,7 +211,8 @@ XrtAllocation::~XrtAllocation() { } /* static */ -StatusOr XrtAllocation::FromLiteral(const Literal& argument) { +StatusOr XrtAllocation::FromLiteral( + const Literal& argument, const string& session_target) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); *alloc.mutable_value() = argument.ToProto(); @@ -221,14 +223,14 @@ StatusOr XrtAllocation::FromLiteral(const Literal& argument) { auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); TF_RETURN_IF_ERROR(root.status()); - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target); tensorflow::ClientSession::FeedType inputs; inputs.insert({literal_string, alloc.SerializeAsString()}); std::vector outputs; TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); int64 handle = outputs[0].scalar()(); - return new XrtAllocation(handle, argument.shape()); + return new XrtAllocation(handle, argument.shape(), session_target); } const int64 XrtAllocation::handle() const { return handle_; } @@ -242,7 +244,7 @@ StatusOr XrtAllocation::ToLiteral() const { auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); TF_RETURN_IF_ERROR(root.status()); - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target_); tensorflow::ClientSession::FeedType inputs; inputs.insert({allocation_handle, handle()}); std::vector outputs; @@ -357,8 +359,11 @@ static StatusOr GetReturnValueShape(const XlaComputation& computation) { } CompiledXrtComputation::CompiledXrtComputation( - const ProgramShape& program_shape, int64 handle) - : program_shape_(program_shape), handle_(handle) {} + const ProgramShape& program_shape, int64 handle, + const string& session_target) + : program_shape_(program_shape), + handle_(handle), + session_target_(session_target) {} CompiledXrtComputation::~CompiledXrtComputation() { tensorflow::Scope root = tensorflow::Scope::NewRootScope(); @@ -371,7 +376,7 @@ CompiledXrtComputation::~CompiledXrtComputation() { return; } - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target_); tensorflow::ClientSession::FeedType inputs; inputs.insert({computation_handle, handle()}); std::vector outputs; @@ -407,7 +412,7 @@ StatusOr CompiledXrtComputation::Execute( e.set_release_input_handles(false); e.set_release_compilation_handle(false); - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target_); tensorflow::ClientSession::FeedType inputs; for (int i = 0; i < arguments.size(); ++i) { inputs.insert({arguments[i], argument_handles[i]->handle()}); @@ -418,7 +423,7 @@ StatusOr CompiledXrtComputation::Execute( TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); int64 output = outputs[0].scalar()(); - return new XrtAllocation(output, program_shape().result()); + return new XrtAllocation(output, program_shape().result(), session_target_); } const ProgramShape& CompiledXrtComputation::program_shape() const { @@ -451,7 +456,7 @@ StatusOr LocalComputation::Compile( } StatusOr LocalComputation::CompileForXrt( - const std::vector& argument_shapes) { + const std::vector& argument_shapes, const string& session_target) { tensorflow::Scope root = tensorflow::Scope::NewRootScope(); auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); auto compile = tensorflow::ops::XRTCompile(root, program); @@ -468,7 +473,7 @@ StatusOr LocalComputation::CompileForXrt( auto snapshot = computation().Snapshot().ValueOrDie(); *c.mutable_hlo_snapshot() = *snapshot; - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target); tensorflow::ClientSession::FeedType inputs; inputs.insert({program, c.SerializeAsString()}); std::vector outputs; @@ -477,7 +482,7 @@ StatusOr LocalComputation::CompileForXrt( TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation().GetProgramShape()); int64 handle = outputs[0].scalar()(); - return new CompiledXrtComputation(program_shape, handle); + return new CompiledXrtComputation(program_shape, handle, session_target); } const XlaComputation& LocalComputation::computation() const { @@ -929,7 +934,7 @@ StatusOr DestructureLocalShapedBufferTuple( } StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation) { + XrtAllocation* allocation, const string& session_target) { const Shape& tuple_shape = allocation->shape(); if (!ShapeUtil::IsTuple(tuple_shape)) { @@ -945,7 +950,7 @@ StatusOr DestructureXrtAllocationTuple( auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); TF_RETURN_IF_ERROR(root.status()); - tensorflow::ClientSession session(root, "local"); + tensorflow::ClientSession session(root, session_target); tensorflow::ClientSession::FeedType inputs; std::vector results; for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { @@ -964,7 +969,8 @@ StatusOr DestructureXrtAllocationTuple( const int64 subtuple_handle = outputs[0].scalar()(); const Shape& subtuple_shape = ShapeUtil::GetTupleElementShape(tuple_shape, i); - results.push_back(new XrtAllocation(subtuple_handle, subtuple_shape)); + results.push_back( + new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); } return new XrtAllocationTuple(std::move(results)); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 720cb77e007485..82f84ddb35bd44 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#include +#include + #include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -110,17 +113,22 @@ StatusOr DestructureLocalShapedBufferTuple( // graph, and an XLA shape to track the referent's shape. class XrtAllocation { public: - static StatusOr FromLiteral(const Literal& argument); + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which allocation and deallocation + // graphs are run. + static StatusOr FromLiteral(const Literal& argument, + const string& session_target); - XrtAllocation(int64 handle, Shape shape); + XrtAllocation(int64 handle, Shape shape, const string& session_target); ~XrtAllocation(); StatusOr ToLiteral() const; const Shape& shape() const; const int64 handle() const; private: - int64 handle_; - Shape shape_; + const int64 handle_; + const Shape shape_; + const string session_target_; }; // Result of a tuple destructuring operation on an XrtAllocation. @@ -145,8 +153,12 @@ class XrtAllocationTuple { // Destructures a tuple-valued XrtAllocation into its constitutent elements // in XrtAllocationTuple form. +// +// Accepts a `session_target` argument, used in constructing the +// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, +// and passed along in constructing each constituent XrtAllocation. StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation); + XrtAllocation* allocation, const string& session_target); // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XLA LocalExecutable. @@ -165,7 +177,10 @@ class CompiledLocalComputation { // device-allocated literals. Specifically, wraps an XRT computation handle. class CompiledXrtComputation { public: - CompiledXrtComputation(const ProgramShape& program_shape, int64 handle); + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the execution graph is run. + CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, + const string& session_target); ~CompiledXrtComputation(); StatusOr Execute( @@ -175,8 +190,9 @@ class CompiledXrtComputation { int64 handle() const; private: - ProgramShape program_shape_; - int64 handle_; + const ProgramShape program_shape_; + const int64 handle_; + const string session_target_; }; // Wraps a XlaComputation produced by a LocalComputationBuilder. The @@ -191,8 +207,10 @@ class LocalComputation { const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the compilation graph is run. StatusOr CompileForXrt( - const std::vector& argument_shapes); + const std::vector& argument_shapes, const string& session_target); const XlaComputation& computation() const; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index ae57dc49feb44c..c13d00d2530c7e 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -451,6 +451,10 @@ tensorflow::ImportNumpy(); // Shape +%typemap(out) const Shape& { + $result = numpy::PyShapeInfoFromXlaShape(*$1); +} + %typemap(out) StatusOr { if ($1.ok()) { $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); @@ -980,6 +984,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBuffer::shape; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 7bc5988480d1f8..07e0e093255b2b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -51,6 +51,10 @@ class BackendType(enum.Enum): XRT = 2 +BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target')) +XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local') + + def OpMetadataToProto(pyobj): proto = xla_data_pb2.OpMetadata() for field in _OP_METADATA_FIELDS: @@ -211,17 +215,17 @@ class LocalBuffer(object): def __init__(self, c_buffer, backend): self.c_buffer = c_buffer self._backend = backend - if backend == BackendType.XRT: + if backend.backend_type == BackendType.XRT: self._delete = c_api.DeleteXrtAllocation else: self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_pyval(pyval, backend=BackendType.XLA_LOCAL): + def from_pyval(pyval, backend=XLA_LOCAL_BACKEND): """Allocate and copy to XLA the given python value.""" pyval = require_numpy_array_layout(pyval) - if backend == BackendType.XRT: - cbuf = c_api.XrtAllocation.FromLiteral(pyval) + if backend.backend_type == BackendType.XRT: + cbuf = c_api.XrtAllocation.FromLiteral(pyval, backend.target) else: cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None) return LocalBuffer(cbuf, backend) @@ -229,6 +233,9 @@ def from_pyval(pyval, backend=BackendType.XLA_LOCAL): def to_py(self): return self.c_buffer.ToLiteral() + def shape(self): + return _wrap_shape(self.c_buffer.shape()) + def delete(self): if self.c_buffer is not None: self._delete(self.c_buffer) @@ -237,8 +244,9 @@ def delete(self): def destructure(self): """Assuming a tuple buffer, unpack it into constituent tuple elements.""" assert self.c_buffer is not None - if self._backend == BackendType.XRT: - result = c_api.DestructureXrtAllocationTuple(self.c_buffer) + if self._backend.backend_type == BackendType.XRT: + result = c_api.DestructureXrtAllocationTuple(self.c_buffer, + self._backend.target) else: result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer) self.delete() @@ -467,14 +475,14 @@ class LocalComputation(object): ComputationBuilder methods. """ - def __init__(self, c_computation, is_compiled, backend=BackendType.XLA_LOCAL): + def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): self._c_computation = c_computation self._backend = backend self._is_compiled = is_compiled # Ensure a reference to C-based destructor for use in __del__. if is_compiled: - if backend == BackendType.XRT: + if backend.backend_type == BackendType.XRT: assert isinstance(c_computation, c_api.CompiledXrtComputation) self._delete = c_api.DeleteCompiledXrtComputation else: @@ -535,8 +543,8 @@ def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - if self._backend == BackendType.XRT: - c = self.computation.CompileForXrt(argument_shapes) + if self._backend.backend_type == BackendType.XRT: + c = self.computation.CompileForXrt(argument_shapes, self._backend.target) else: c = self.computation.Compile(argument_shapes, compile_options) return LocalComputation(c, is_compiled=True, backend=self._backend) @@ -590,7 +598,7 @@ def __init__(self, name): self._client = c_api.LocalComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self, root=None, backend=BackendType.XLA_LOCAL): + def Build(self, root=None, backend=XLA_LOCAL_BACKEND): if root is not None: return LocalComputation( self._client.BuildWithRoot(root), is_compiled=False, backend=backend) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index d04eef81dc8345..21b5c93b615ec4 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -439,6 +439,13 @@ def testDestructureTupleNested(self): np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + def testShape(self): + pyval = np.array([[1., 2.]], np.float32) + local_buffer = xla_client.LocalBuffer.from_pyval(pyval) + xla_shape = local_buffer.shape() + self.assertEqual(xla_shape.dimensions(), (1, 2,)) + self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) + class SingleOpTest(LocalComputationTest): """Tests for single ops. diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 017b11465d1400..04b2f72ac9525e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -323,7 +323,6 @@ cc_library( ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", - ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal", @@ -402,6 +401,7 @@ cc_library( srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], deps = [ + ":hlo", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", @@ -1103,6 +1103,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_reachability", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1362,6 +1363,7 @@ cc_library( ":fusion_queue", ":hlo", ":hlo_pass", + ":hlo_reachability", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", @@ -1387,6 +1389,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":hlo_reachability", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", @@ -3241,6 +3244,7 @@ cc_library( ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", + "@com_google_absl//absl/strings", ], ) @@ -3365,6 +3369,7 @@ cc_library( ":bfloat16_normalization", ":defuser", ":hlo", + ":hlo_memory_scheduler", ":hlo_pass", ":hlo_pass_pipeline", ":implicit_broadcast_remover", @@ -3448,6 +3453,7 @@ tf_cc_test( ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 72ed5ca4821729..85fc42f7475645 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -306,6 +306,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Tries to use a kDot in place of the given convolution. StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Tries to simplify a slice(pad(...)) where the result of the slice is a + // scalar. + StatusOr TrySimplifySliceOfPad(HloInstruction* slice); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -1822,6 +1826,62 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::TrySimplifySliceOfPad( + HloInstruction* slice) { + // Only try to do this for effective scalars. We could do the same for slicing + // out larger pieces of padding (replacing with a broadcast of the padding + // value), but this is probably not worth it. + if (!ShapeUtil::IsEffectiveScalar(slice->shape()) || + slice->operand(0)->opcode() != HloOpcode::kPad) { + return false; + } + + VLOG(10) << "Trying to simplify scalar slice of pad"; + // Check there's no internal padding. Again, we could handle that too, since + // everything is statically known, but it's not worth it. + auto pad = Cast(slice->mutable_operand(0)); + auto padding_config = pad->padding_config(); + int64 rank = padding_config.dimensions_size(); + if (HasInteriorPadding(padding_config)) { + VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; + return false; + } + + // Check whether the scalar we're slicing out falls into the padding. + bool in_padding = [&]() { + for (int64 i = 0; i < rank; ++i) { + int64 start = slice->slice_starts(i); + int64 low = padding_config.dimensions(i).edge_padding_low(); + int64 data = pad->operand(0)->shape().dimensions(i); + if (start >= low && start < low + data) { + return false; + } + } + return true; + }(); + + if (in_padding) { + VLOG(10) << "Folding scalar slice of pad into padding value"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), + pad->mutable_padding_value()))); + return true; + } else { + // We already know the output of the slice is scalar. If the padded + // value is scalar, and it's not in the padding, then it's exactly the + // output value. + bool replaced = + ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + if (replaced) { + VLOG(10) << "Folding scalar slice of pad into padded value"; + } else { + VLOG(10) << "Not folding scalar slice of pad into padded value as they " + "have different shapes."; + } + return replaced; + } +} + Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Delete no-op slices, i.e. where shape = operand shape. if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { @@ -1846,6 +1906,12 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { slice->shape(), operand_slice->mutable_operand(0), new_slice_starts, new_slice_limits, slice->slice_strides())); } + + TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifySliceOfPad(slice)); + if (replaced) { + return Status::OK(); + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index c79c518700b63b..7b3e957fbcf9f4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -3163,6 +3163,92 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { EXPECT_EQ(Cast(root)->iota_dimension(), 2); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[1,1] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter()); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index d63287539dfde5..e9d30fc03c1c31 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -151,15 +151,10 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { // Do not fold BF16 conversions for instructions related to tuples, entry and - // exit of a computation, fusion, convert, and control flow. + // exit of a computation, fusion, convert, side-effecting instructions and + // control flow. if (hlo->opcode() == HloOpcode::kTuple || // hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kInfeed || // - hlo->opcode() == HloOpcode::kOutfeed || // - hlo->opcode() == HloOpcode::kSend || // - hlo->opcode() == HloOpcode::kSendDone || // - hlo->opcode() == HloOpcode::kRecv || // - hlo->opcode() == HloOpcode::kRecvDone || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -167,7 +162,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kCall || // hlo->opcode() == HloOpcode::kCustomCall || // hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional) { + hlo->opcode() == HloOpcode::kConditional || // + hlo->HasSideEffectNoRecurse()) { return Status::OK(); } if (hlo == computation_->root_instruction() && @@ -182,6 +178,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( HloInstruction* crs) { + if (crs->IsCrossModuleAllReduce()) { + // Cross-module all-reduce has side effect. + return Status::OK(); + } // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. TF_RETURN_IF_ERROR(DefaultAction(crs)); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 1251f0258f5d43..b8a8f844eff17a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -346,11 +346,9 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { // Do not change instructions related to entry and exit of a computation, - // tuples, fusion, convert, and control flow. + // tuples, fusion, convert, side-effecting instructions, and control flow. if (hlo->opcode() == HloOpcode::kTuple || // hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kInfeed || // - hlo->opcode() == HloOpcode::kOutfeed || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -358,7 +356,8 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kCall || // hlo->opcode() == HloOpcode::kCustomCall || // hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional) { + hlo->opcode() == HloOpcode::kConditional || // + hlo->HasSideEffectNoRecurse()) { return Status::OK(); } // TODO(b/112040122): Correctly normalize variadic reduce. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 002be9c97098ef..63d4572f2028c4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -236,6 +236,10 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // the end of the BFloat16Propagation pass. continue; } + if (use.instruction->HasSideEffectNoRecurse()) { + // Keep side-effecting instruction's operands unchanged. + return false; + } // Any visited user that can accept BF16 has already been updated if // necessary, e.g., the output has been changed to BF16 if it propagates // precision, or a called computation's parameters have been changed to @@ -329,22 +333,6 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } - // Do not change precision for instructions related to entry and exit of a - // computation, and control flow, because this pass might break the interfaces - // or assumptions for them. - if (hlo->opcode() == HloOpcode::kInfeed || // - hlo->opcode() == HloOpcode::kOutfeed || // - hlo->opcode() == HloOpcode::kSend || // - hlo->opcode() == HloOpcode::kSendDone || // - hlo->opcode() == HloOpcode::kRecv || // - hlo->opcode() == HloOpcode::kRecvDone || // - hlo->opcode() == HloOpcode::kCustomCall || // - hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kConditional || // - (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { - return; - } - // Prevent root instructions from having their output modified by recording // all F32 output values as needing to stay as F32. CHECK(hlo->parent() != nullptr); @@ -366,6 +354,17 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } + // Do not change precision for instructions related to entry and exit of a + // computation, side-effecting instructions, and control flow, because this + // pass might break the interfaces or assumptions for them. + if (hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kConditional || // + hlo->HasSideEffectNoRecurse() || // + (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { + return; + } + if (!ContainsKey(consider_using_bfloat16_, hlo)) { return; } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index e032b5c624c015..0af71eaac96fca 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -136,6 +136,40 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { EXPECT_FALSE(OutputsBF16(c)); } +// Tests that side-effecting all-reduce should not be changed. +TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { + auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + auto rb = HloComputation::Builder(TestName()); + rb.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")), + rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")))); + auto reduction = module->AddEmbeddedComputation(rb.Build()); + HloInstruction* all_reduce = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, + /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, all_reduce, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, all_reduce, 1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), root); +} + // Tests that if a constant is converted to BF16 then its literal must also be // converted. TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index da01c0caf2a666..4ce5a8a29255a7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -502,8 +502,8 @@ Status CreateHloProfilingArtifacts( HloCostAnalysis cost_analysis(shape_size_bytes); TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis)); - *hlo_profile_printer_data = - CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis); + *hlo_profile_printer_data = CreateHloProfilePrinterData( + **hlo_profile_index_map, cost_analysis, entry_computation.name()); *computation_to_profile_idx = (*hlo_profile_index_map)->computation_to_profile_idx(); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 99fa707c959854..97f9b85a606e14 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1546,10 +1546,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -// Return whether the given shape is a matrix with no padding. -static bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); -} +// Return whether the given shape is rank 2. +static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. @@ -1565,8 +1563,7 @@ static bool AreValidGemmShapes( return false; } - if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape))) { + if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index b1f81ad73b57f4..d6968323f337d8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2206,16 +2206,16 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - // Delegate to common implementation of fused in-place dynamic-update-slice. - auto operands = GetIrArraysForOperandsOf(fusion); return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( - fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_); + fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion), + &elemental_emitter, &b_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); auto operands = GetIrArraysForOperandsOf(fusion); - FusedIrEmitter fused_emitter(operands, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); @@ -2415,14 +2415,8 @@ StatusOr IrEmitter::EmitFastConcatenate( *failure_reason = "operand has mismatching layouts"; return false; } - if (LayoutUtil::IsPadded(op->shape())) { - *failure_reason = "operand has padded layout"; - return false; - } } - CHECK(!LayoutUtil::IsPadded(concatenate->shape())); - // We split the dimensions into three categories: the dimension over which we // are concatenating (concat_dim), the dimensions that are minor to it // (inner_dims) and the dimensions that are major to it (outer_dims). diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 586f27b104ed70..136b88ff75ea8a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -59,6 +59,9 @@ namespace cpu { class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { public: + using GeneratorForOperandIrArrays = + std::function()>; + // Create a new LLVM IR emitter. // // hlo_module: the HLO module we are emitting IR for. @@ -208,6 +211,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector GetIrArraysForOperandsOf( const HloInstruction* hlo); + GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays( + HloInstruction* unnested_hlo) { + return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); }; + } + // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm_ir::IrArray* array) { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index b3549acfc291a5..ed37099a542807 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/defuser.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" namespace xla { @@ -45,6 +46,7 @@ class ControlDepRemover : public HloModulePass { Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 6ee8a9f3e19151..87a835f2504068 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -484,7 +484,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } -// Extracted from //learning/brain/google/xla/benchmarks/resnet.py +// Extracted from Resnet-50. // // For simplicity, we focus on the column dimension and ignore other dimensions. // We use [?] to represent the shape instead of the content. diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 0006e85e160e26..492d290bf4a27a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -126,9 +126,9 @@ Status RunCudnnConvImpl(CudnnConvParams params, int64 feature_group_count = params.feature_group_count; AlgorithmConfig algorithm = params.algorithm; - VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm()->algo_id(); VLOG(3) << "tensor_ops_enabled: " - << algorithm.algorithm().tensor_ops_enabled(); + << algorithm.algorithm()->tensor_ops_enabled(); VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); @@ -302,8 +302,8 @@ Status RunCudnnConvImpl(CudnnConvParams params, if (!stream->ok()) { return InternalError( "Unable to launch convolution with type %s and algorithm (%d, %d)", - CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), - algorithm.algorithm_no_scratch().algo_id()); + CudnnConvKindToString(kind), algorithm.algorithm()->algo_id(), + algorithm.algorithm_no_scratch()->algo_id()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 9c4a4903667ea1..27f07b1d581250 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -51,7 +52,8 @@ struct MatrixDescriptor { // rhs_matrix, and stores the result to output_matrix. template bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, double beta, + se::Stream* stream) { DCHECK(!output_matrix.transpose); const int64 batch_size = lhs_matrix.batch_size; @@ -73,7 +75,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, lhs_transpose, rhs_transpose, output_matrix.num_rows, output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta, &output_data, /*leading dim of output=*/output_matrix.num_rows) .ok(); } @@ -88,7 +90,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, - /*beta=*/0.0, &output_data, + /*beta=*/beta, &output_data, /*leading dim of output=*/output_matrix.num_rows, output_stride, batch_size) .ok(); @@ -112,6 +114,7 @@ template bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, + double beta, se::blas::ComputationType computation_type, se::blas::AlgorithmType algorithm, se::Stream* stream, se::blas::ProfileResult* output_profile_result) { @@ -138,7 +141,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, /*alpha=*/static_cast(alpha), lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, - /*beta=*/static_cast(0.0f), &output_data, + /*beta=*/static_cast(beta), &output_data, /*leading dim of output=*/output_matrix.num_rows, computation_type, algorithm, output_profile_result) .ok(); @@ -153,7 +156,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, template StatusOr DoGemmAutotune( MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, double alpha, + MatrixDescriptor output_matrix, double alpha, double beta, se::blas::ComputationType computation_type, se::Stream* stream) { std::vector algorithms; CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); @@ -166,7 +169,7 @@ StatusOr DoGemmAutotune( // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - alpha, computation_type, algorithm, + alpha, beta, computation_type, algorithm, stream, &profile_result)); if (profile_result.is_valid()) { @@ -263,8 +266,9 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { } CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); - CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), - HloOpcode::kMultiply); + CHECK(hlo_instruction.fused_expression_root()->opcode() == HloOpcode::kAdd || + hlo_instruction.fused_expression_root()->opcode() == + HloOpcode::kMultiply); // Try to find the dot inside the output fusion node. const HloInstruction* dot = hlo_instruction.fused_expression_root()->operand(0); @@ -282,8 +286,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, double alpha, - const HloInstruction* hlo_instruction) + const Shape& output_shape, double alpha, double beta, + const HloInstruction* hlo_instruction, + bool implements_whole_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), @@ -291,7 +296,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - alpha_(alpha) {} + alpha_(alpha), + beta_(beta), + implements_whole_instruction_(implements_whole_instruction) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, @@ -386,7 +393,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // TODO(b/112111608): Implement auto tune for batched gemm. if (batch_size != 1) { return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, stream); + alpha_, beta_, stream); } auto thunk_name = [&] { @@ -398,9 +405,27 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, auto autotune_it = autotune_results_.find(device_name); if (autotune_it == autotune_results_.end()) { VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); - StatusOr best_algorithm = - GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, computation_type, stream); + + // If the output buffer already contains a bias then autotune into a + // scratch buffer. This avoids overwriting the bias buffer. The scratch + // buffer may contain arbitrary garbage values. + se::DeviceMemoryBase scratch_data = output_data; + std::unique_ptr> scratch_mem; + if (beta_ != 0.0) { + auto temp_status = stream->AllocateTemporaryArray( + ShapeUtil::ByteSizeOf(output_shape_)); + if (!temp_status.ok()) { + return false; + } + scratch_mem = std::move(temp_status).ValueOrDie(); + scratch_data = scratch_mem->device_memory(); + } + const MatrixDescriptor scratch_descriptor( + scratch_data, false, output_num_cols, output_num_rows, batch_size); + + StatusOr best_algorithm = GetGemmAutotuneFn( + element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, + beta_, computation_type, stream); autotune_it = autotune_results_.insert({device_name, best_algorithm}).first; @@ -421,18 +446,19 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(2) << "Using algorithm " << algorithm << " chosen by autotuning on GemmThunk " << thunk_name(); return GetGemmWithAlgorithmFn(element_type)( - lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, - algorithm, stream, + lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_, + computation_type, algorithm, stream, /*output_profile_result=*/nullptr); } // Autotune will fail when CUDA 8 and GPU sm_50 or older are used. // Use the older Gemm API in this case. return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, stream); + alpha_, beta_, stream); }; - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + auto op_profiler = profiler->MakeScopedInstructionProfiler( + implements_whole_instruction_ ? hlo_instruction() : nullptr); bool launch_ok; if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { launch_ok = launch(lhs_descriptor, rhs_descriptor, diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 12c81f9bfc6bfd..cc2d12a39c045f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -41,8 +41,9 @@ class GemmThunk : public Thunk { const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, double alpha, - const HloInstruction* hlo_instruction); + const Shape& output_shape, double alpha, double beta, + const HloInstruction* hlo_instruction, + bool implements_whole_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -70,6 +71,9 @@ class GemmThunk : public Thunk { const Shape output_shape_; const double alpha_; + const double beta_; + + const bool implements_whole_instruction_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune // results. The map's value is the best algorithm we've found for this thunk diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 02a0d028c118ab..d8f2e9f6abd8bb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering( for (auto* computation : module->computations()) { if (computation != module->entry_computation() && !computation->IsFusionComputation()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, + HloReachabilityMap::Build(computation)); } } } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index c80cdc810a3f15..7f2b59810f0334 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -179,6 +179,10 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kAdd) { + // Fuse a bias add into the output of the dot. + return true; } } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 7ca72abff19242..57e66f5a12cf54 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -331,6 +331,33 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { op::Broadcast(op::Constant()))); } +TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,4]{1,0} parameter(2) + transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add = f32[4,4] add(dot, p2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Parameter())); +} + // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index ec3d8f9405840b..42fb38dffae31b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -38,10 +38,9 @@ namespace gpu { namespace { -// Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) { - return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 && - !LayoutUtil::IsPadded(shape); +// Return whether the given shape is rank 2 excluding the batch dimensions. +bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { + return ShapeUtil::Rank(shape) == batch_dimensions_size + 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes @@ -56,10 +55,9 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || output_primitive_type == F64 || output_primitive_type == C64); - return type_is_allowed && - IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && - IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && - IsRank2WithNoPadding(output_shape, batch_dimensions_size) && + return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && + IsRank2(rhs_shape, batch_dimensions_size) && + IsRank2(output_shape, batch_dimensions_size) && !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); } @@ -93,7 +91,8 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kFusion && hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && - hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) { + (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply || + hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) { // Try to find the dot inside the output fusion node. const HloInstruction* dot = hlo.fused_expression_root()->operand(0); if (dot->opcode() != HloOpcode::kDot) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a3821e077ecf6b..7fcdd805ed3200 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -697,15 +697,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { // kFusion for library calls should be handled by // IrEmitterUnnested::HandleFusion. - CHECK(HloInstruction::FusionKind::kLoop == fusion->fusion_kind()); - - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand, *fusion)); - } + CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind()); GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator()); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 88052014800583..56c3f452006f9e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -68,6 +68,9 @@ namespace gpu { class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { public: + using GeneratorForOperandIrArrays = + std::function()>; + IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -179,6 +182,20 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Hlo configuration data used during code generation. const HloModuleConfig& hlo_module_config_; + protected: + GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays( + HloInstruction* fusion) { + return [=]() { + std::vector ir_arrays; + ir_arrays.reserve(fusion->operand_count()); + absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays), + [&](const HloInstruction* operand) { + return GetIrArray(*operand, *fusion); + }); + return ir_arrays; + }; + } + private: // A helper method for EmitAtomicOperationForNestedComputation. Certain // computations, such as floating-point addition and integer maximization, can diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 921c9decf99604..21e44e1e7d3fb7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -337,14 +337,6 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { - int unroll_factor = 1; - // Unfused elementwise operations are usually memory bound, unroll them. - if (hlo->IsElementwise()) { - unroll_factor = ComputeMaxUnrollFactor(hlo); - } - - AddThunkToThunkSequence(BuildKernelThunk( - hlo, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::DefaultAction(hlo); } @@ -505,15 +497,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back(BuildKernelThunk( fusion, /*implements_whole_instruction=*/false, unroll_factor)); - std::vector operand_parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - operand_parameter_arrays.push_back(GetIrArray(*operand, *fusion)); - } GpuElementalIrEmitter operand_elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter operand_fused_emitter(operand_parameter_arrays, - &operand_elemental_emitter); + FusedIrEmitter operand_fused_emitter( + GetGeneratorForOperandIrArrays(fusion), + &operand_elemental_emitter); TF_RETURN_IF_ERROR( root->mutable_operand(0)->Accept(&operand_fused_emitter)); @@ -529,15 +518,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); // Spin up a new fused emitter for the scatter kernel and emit it. - std::vector scatter_parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - scatter_parameter_arrays.push_back(GetIrArray(*operand, *fusion)); - } GpuElementalIrEmitter scatter_elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter scatter_fused_emitter(scatter_parameter_arrays, - &scatter_elemental_emitter); + FusedIrEmitter scatter_fused_emitter( + GetGeneratorForOperandIrArrays(fusion), + &scatter_elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter)); TF_RETURN_IF_ERROR(EmitScatter( thunks.back().get(), root, @@ -585,14 +571,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { CHECK(first_reduce != nullptr); std::unique_ptr kernel_thunk = BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand, *fusion)); - } GpuElementalIrEmitter elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); // For multi-output fusion CHECK the constraints and feed all the @@ -663,10 +646,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // Set up kernel thunk and fused ir emitter. std::unique_ptr fusion_thunk = BuildKernelThunk(fusion, /*implements_whole_instruction=*/true); - std::vector operand_arrays; - for (HloInstruction* operand : fusion->operands()) { - operand_arrays.push_back(GetIrArray(*operand, *fusion)); - } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); @@ -685,8 +664,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { AddThunkToThunkSequence(std::move(fusion_thunk)); return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( - fusion, operand_arrays, output_array, &elemental_emitter, - launch_dimensions, &b_); + fusion, GetGeneratorForOperandIrArrays(fusion), output_array, + &elemental_emitter, launch_dimensions, &b_); } if (ImplementedAsGemm(*fusion)) { @@ -700,10 +679,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int unroll_factor = ComputeMaxUnrollFactor(fusion); - - AddThunkToThunkSequence(BuildKernelThunk( - fusion, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::HandleFusion(fusion); } @@ -1629,7 +1604,7 @@ Status IrEmitterUnnested::EmitReductionToVector( // the dimensions to keep are contiguous, by prerequisite of // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. - if (input_dims_to_keep.empty()) { + if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, init_value_gens, reducers, reduce_output_shapes, extra_output_gens); @@ -1721,8 +1696,6 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { return Status::OK(); } - AddThunkToThunkSequence( - BuildKernelThunk(reduce, /*implements_whole_instruction=*/true)); return IrEmitter::HandleReduce(reduce); } @@ -2192,8 +2165,6 @@ Status IrEmitterUnnested::EmitScatter( } Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { - AddThunkToThunkSequence( - BuildKernelThunk(select, /*implements_whole_instruction=*/true)); return IrEmitter::HandleSelect(select); } @@ -2655,28 +2626,43 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. 1.0, // alpha. - inst); + 0.0, // beta. + inst, /*implements_whole_instruction=*/true); } if (inst->opcode() == HloOpcode::kFusion) { CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); - const HloInstruction* mul = inst->fused_expression_root(); - const HloInstruction* dot = mul->operand(0); - const HloInstruction* alpha = mul->operand(1); - if (dot->opcode() != HloOpcode::kDot) { - std::swap(dot, alpha); - } - if (alpha->opcode() == HloOpcode::kBroadcast) { - alpha = alpha->operand(0); - } - if (alpha->opcode() == HloOpcode::kParameter) { - alpha = inst->operand(alpha->parameter_number()); - } - // TODO(b/74185543): Remove the following if block once we support fusion - // with a non-constant as well. Then we will just always use the constant - // on the device. - if (alpha->opcode() == HloOpcode::kCopy) { - alpha = alpha->operand(0); + const HloInstruction* output_fused_op = inst->fused_expression_root(); + + double alpha_value = 1.0; + const HloInstruction* bias = nullptr; + const HloInstruction* dot = output_fused_op->operand(0); + if (output_fused_op->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = output_fused_op->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); + } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + alpha_value = GetScalarConstantAsDouble(alpha->literal()); + } else { + // Fused bias add. + CHECK_EQ(output_fused_op->opcode(), HloOpcode::kAdd); + bias = output_fused_op->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, bias); + } + bias = inst->operand(bias->parameter_number()); } DCHECK(dot->opcode() == HloOpcode::kDot); @@ -2689,15 +2675,38 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); + // The bias is passed inside the output buffer. If those buffers are shared + // we can just use it, otherwise copy the bias values into the output buffer + // first. + if (bias != nullptr && + GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { + std::vector> thunks; + thunks.push_back(absl::make_unique( + /*source_buffer=*/GetAllocationSlice(*bias), + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr)); + thunks.push_back(absl::make_unique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + alpha_value, // alpha. + 1.0, // beta. + inst, /*implements_whole_instruction=*/false)); + return absl::make_unique(std::move(thunks), inst); + } return absl::make_unique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - GetScalarConstantAsDouble(alpha->literal()), // alpha. - inst); + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + alpha_value, // alpha. + bias != nullptr ? 1.0 : 0.0, // beta. + inst, /*implements_whole_instruction=*/true); } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2806,15 +2815,12 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( if (fused) { // If init_value was fused into this reduce we have to generate it first. - std::vector parameter_arrays; - for (HloInstruction* operand : hlo->operands()) { - parameter_arrays.push_back(GetIrArray(*operand, *hlo)); - } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elemental_emitter); TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); TF_RETURN_IF_ERROR( ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), @@ -3037,9 +3043,19 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - CHECK_EQ(Thunk::Kind::kKernel, LastThunk()->kind()); - return EmitTargetElementLoopInThunk(hlo, element_generator, - static_cast(LastThunk())); + int unroll_factor = 1; + // Unfused elementwise operations are usually memory bound, unroll them. + if (hlo.IsElementwise() || hlo.opcode() == HloOpcode::kFusion) { + unroll_factor = ComputeMaxUnrollFactor(&hlo); + } + + std::unique_ptr kernel_thunk = BuildKernelThunk( + &hlo, /*implements_whole_instruction=*/true, unroll_factor); + Status emit_status = + EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get()); + thunk_sequence_->emplace_back(std::move(kernel_thunk)); + + return emit_status; } std::vector IrEmitterUnnested::ConstructIrArrayForInputs( @@ -3403,7 +3419,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( [&](const IrArray::Index& index, llvm::Value* y_loc) { GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(param_arrays, &elem_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elem_emitter); tiled_param_info.set_y(y_loc); fused_emitter.SetTiledParameterInfo(&tiled_param_info); TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 8a6e5327e08279..1d4856e0cae163 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest, p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast) p0.1 = f16[2,2,2]{2,1,0} parameter(0) ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 19ed70ddc3f398..de04ed85c30717 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -128,6 +128,7 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { << potential_libdevice_dir; } + LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; // Last resort: maybe in the current folder. return "."; } @@ -242,7 +243,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, * fixing the ticket. */ pipeline.AddInvariantChecker( /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, nullptr); + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -295,7 +297,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, * fixing the ticket. */ fusion.AddInvariantChecker( /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, nullptr); + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); @@ -309,7 +312,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after * fixing the ticket. */ reduce_pipeline.AddInvariantChecker( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, nullptr); + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -339,7 +343,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { * fixing the ticket. */ pipeline.AddInvariantChecker( /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, nullptr); + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -472,9 +477,10 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Using ptxas at " << ptxas_path; + VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + VLOG(2) << "Using ptxas at " << ptxas_path; WarnIfBadPtxasVersion(ptxas_path); @@ -540,8 +546,8 @@ StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { // We dump the post-optimization HLO in RunBackend so no need to dump it here. - VLOG(2) << "*** HLO Before Optimization"; - XLA_VLOG_LINES(2, module->ToString()); + VLOG(3) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(3, module->ToString()); XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), @@ -600,8 +606,8 @@ StatusOr> NVPTXCompiler::RunBackend( // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); - VLOG(2) << "*** HLO After Optimization"; - XLA_VLOG_LINES(2, module->ToString()); + VLOG(3) << "*** HLO After Optimization"; + XLA_VLOG_LINES(3, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); if (!xla_dump_optimized_hlo_proto_to.empty()) { @@ -631,10 +637,10 @@ StatusOr> NVPTXCompiler::RunBackend( string ir_module_string_before_opt; const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - if (VLOG_IS_ON(2) || embed_ir_in_executable) { + if (VLOG_IS_ON(3) || embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); - VLOG(2) << "LLVM module before optimizations:"; - XLA_VLOG_LINES(2, ir_module_string_before_opt); + VLOG(3) << "LLVM module before optimizations:"; + XLA_VLOG_LINES(3, ir_module_string_before_opt); } const string& ir_dump_directory = @@ -678,6 +684,8 @@ StatusOr> NVPTXCompiler::RunBackend( } libdevice_dir = cached_libdevice_dir_; } + VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n"; + int cc_major, cc_minor; if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor)) { @@ -704,10 +712,10 @@ StatusOr> NVPTXCompiler::RunBackend( if (user_post_optimization_hook_) { TF_CHECK_OK(user_post_optimization_hook_(llvm_module)); } - VLOG(2) << "LLVM module after optimizations:"; - XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); - VLOG(2) << "PTX:"; - XLA_VLOG_LINES(2, ptx); + VLOG(3) << "LLVM module after optimizations:"; + XLA_VLOG_LINES(3, llvm_ir::DumpModuleToString(llvm_module)); + VLOG(3) << "PTX:"; + XLA_VLOG_LINES(3, ptx); // Write PTX to IR dump directory, if IR dumping was requested. if (!ir_dump_directory.empty()) { @@ -731,8 +739,8 @@ StatusOr> NVPTXCompiler::RunBackend( auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); - VLOG(2) << "Printing the thunk schedule..."; - XLA_VLOG_LINES(2, thunk_schedule->ToString()); + VLOG(3) << "Printing the thunk schedule..."; + XLA_VLOG_LINES(3, thunk_schedule->ToString()); std::unique_ptr profile_index_map; std::unique_ptr profile_printer; @@ -743,8 +751,8 @@ StatusOr> NVPTXCompiler::RunBackend( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = absl::make_unique(*module); - profile_printer = - CreateHloProfilePrinterData(*profile_index_map, cost_analysis); + profile_printer = CreateHloProfilePrinterData( + *profile_index_map, cost_analysis, entry_computation->name()); } auto* gpu_executable = new GpuExecutable( diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5b6cf2c04d0537..4775baf44aecfe 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -122,7 +122,7 @@ std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = - computation.ComputeReachability(); + HloReachabilityMap::Build(&computation); std::vector seen_gemms; // The execution of different RNG Hlo instructions in the same module updates // a common global variable. To avoid a race condition, we simply assign all diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b0f7cd91ad1db0..01ae6a55fcf2d6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -739,72 +739,6 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -std::unique_ptr HloComputation::ComputeReachability() - const { - const auto& all = MakeInstructionPostOrder(); - auto result = absl::make_unique(all); - auto channel_dependency_map = ComputeChannelDependencies(); - - std::vector inputs; - for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); - - switch (hlo->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - break; - } - case HloOpcode::kCrossReplicaSum: { - auto all_reduce_id = hlo->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - } - break; - } - default: - break; - } - - result->FastSetReachabilityToUnion(inputs, hlo); - } - return result; -} - -void HloComputation::UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map) { - std::queue worklist; - worklist.push(instruction); - - std::vector inputs; - - while (!worklist.empty()) { - const HloInstruction* item = worklist.front(); - worklist.pop(); - - inputs.assign(item->operands().begin(), item->operands().end()); - inputs.insert(inputs.end(), item->control_predecessors().begin(), - item->control_predecessors().end()); - - if (reachability_map->SetReachabilityToUnion(inputs, item)) { - // Add immediate successors to worklist. - for (const HloInstruction* user : item->users()) { - worklist.push(user); - } - for (const HloInstruction* succ : item->control_successors()) { - worklist.push(succ); - } - } - } -} - std::vector HloComputation::CollectUnreachableRoots() const { std::vector unreachable_roots; for (auto* instruction : instructions()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index dec96d11a93cf5..2cce866e5c17c0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" @@ -215,19 +214,6 @@ class HloComputation { // this order, definitions of values always appear before their uses. std::vector MakeInstructionPostOrder() const; - // Computes and returns the reachability between HLO instructions in the - // computation. The returned HloReachabilityMap is constructed such that - // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a - // directed path (from producer to consumer) from 'a' to 'b'. Both data - // dependencies (operands) and control dependencies are considered for - // reachability. Trivially an instruction is reachable from itself. - std::unique_ptr ComputeReachability() const; - - // Updates the given reachability map after the immediate predecessor set - // (operands and control predecessors) of 'instruction' has changed. - void UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this @@ -355,6 +341,14 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + absl::flat_hash_map>; + ChannelDependencyMap ComputeChannelDependencies() const; + // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. bool HasSideEffect() const; @@ -410,14 +404,6 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating - // instructions. - using ChannelDependencyMap = - absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; - enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 2aaaef1d36d58b..ac6d08b026ad33 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -484,107 +484,6 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } -TEST_F(HloComputationTest, Reachability) { - // Test reachability of a non-trivial computation: - // - // const1 const2 - // | | - // | +-------+ - // | | | - // add .. negate - // | . | - // | .... exp - // | | - // +---+ +-+---+ - // | | | - // multiply copy - // - // There is a control dependency from 'add' to 'exp'. - auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); - - auto module = CreateNewModule(); - auto computation = - module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); - - TF_CHECK_OK(add->AddControlDependencyTo(exp)); - auto reachability = computation->ComputeReachability(); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_TRUE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_TRUE(reachability->IsReachable(constant1, copy)); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_TRUE(reachability->IsReachable(constant2, negate)); - EXPECT_TRUE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_TRUE(reachability->IsReachable(constant2, copy)); - - EXPECT_FALSE(reachability->IsReachable(exp, constant1)); - EXPECT_FALSE(reachability->IsReachable(exp, constant2)); - EXPECT_FALSE(reachability->IsReachable(exp, add)); - EXPECT_FALSE(reachability->IsReachable(exp, negate)); - EXPECT_TRUE(reachability->IsReachable(exp, exp)); - EXPECT_TRUE(reachability->IsReachable(exp, mul)); - EXPECT_TRUE(reachability->IsReachable(exp, copy)); - - EXPECT_FALSE(reachability->IsReachable(mul, constant1)); - EXPECT_FALSE(reachability->IsReachable(mul, constant2)); - EXPECT_FALSE(reachability->IsReachable(mul, add)); - EXPECT_FALSE(reachability->IsReachable(mul, negate)); - EXPECT_FALSE(reachability->IsReachable(mul, exp)); - EXPECT_TRUE(reachability->IsReachable(mul, mul)); - EXPECT_FALSE(reachability->IsReachable(mul, copy)); - - EXPECT_TRUE(reachability->IsConnected(constant1, copy)); - EXPECT_TRUE(reachability->IsConnected(copy, constant1)); - EXPECT_FALSE(reachability->IsConnected(negate, add)); - EXPECT_FALSE(reachability->IsConnected(add, negate)); - - // Remove the control dependency then update and verify the reachability map - ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); - computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_FALSE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_FALSE(reachability->IsReachable(constant1, copy)); - - // Change a use within the graph then update and verify the reachability map - ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); - computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_FALSE(reachability->IsReachable(constant2, negate)); - EXPECT_FALSE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_FALSE(reachability->IsReachable(constant2, copy)); -} - TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -700,27 +599,5 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -TEST_F(HloComputationTest, ChannelReachability) { - const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); - HloComputation::Builder builder("ChannelReachability"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); - auto send = - builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); - auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); - auto recv = - builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); - auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(recv_done)); - auto reachability = computation->ComputeReachability(); - EXPECT_TRUE(reachability->IsReachable(param, recv_done)); - EXPECT_FALSE(reachability->IsReachable(send, recv)); - EXPECT_FALSE(reachability->IsReachable(send_done, recv)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index ce4cad42355ec5..2df8eb962ae54e 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -28,7 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { -HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { +HloProfileIndexMap::HloProfileIndexMap(const HloModule& module, + absl::Span extra_metrics) { size_t current_profile_index = 0; for (xla::HloComputation* computation : module.MakeComputationPostOrder()) { InsertOrDie(&computation_to_profile_idx_, computation, @@ -40,11 +41,15 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { current_profile_index++); } } + for (const string& key : extra_metrics) { + InsertOrDie(&extra_metric_to_profile_idx_, key, current_profile_index++); + } } std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, - const HloCostAnalysis& cost_analysis) { + const HloCostAnalysis& cost_analysis, + const string& entry_computation_name) { using HloComputationInfo = HloProfilePrinterData::HloComputationInfo; using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo; @@ -105,6 +110,14 @@ std::unique_ptr CreateHloProfilePrinterData( } } + // Add extra metrics if any. + for (const auto& pair : hlo_profile_index_map.extra_metric_to_profile_idx()) { + profile_printer_data->mutable_extra_metrics()->insert( + {pair.first, pair.second}); + } + + profile_printer_data->set_entry_computation(entry_computation_name); + return profile_printer_data; } diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index be989846ef5cd2..da30e15908328f 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_ #include +#include #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -34,7 +35,10 @@ class HloInstruction; class HloProfileIndexMap { public: // Scans `module` to populate this instance of HloProfileIndexMap. - explicit HloProfileIndexMap(const HloModule& module); + explicit HloProfileIndexMap(const HloModule& module) + : HloProfileIndexMap(module, {}) {} + explicit HloProfileIndexMap(const HloModule& module, + absl::Span extra_metrics); HloProfileIndexMap(const HloProfileIndexMap&) = default; HloProfileIndexMap(HloProfileIndexMap&&) = default; @@ -50,6 +54,10 @@ class HloProfileIndexMap { return FindOrDie(computation_to_profile_idx(), &computation); } + size_t GetProfileIndexFor(const string& key) const { + return xla::FindOrDie(extra_metric_to_profile_idx(), key); + } + size_t instruction_count() const { return instruction_to_profile_idx().size(); } @@ -58,8 +66,12 @@ class HloProfileIndexMap { return computation_to_profile_idx().size(); } + size_t extra_metrics_count() const { + return extra_metric_to_profile_idx().size(); + } + size_t total_count() const { - return instruction_count() + computation_count(); + return instruction_count() + computation_count() + extra_metrics_count(); } const std::unordered_map& @@ -72,15 +84,20 @@ class HloProfileIndexMap { return computation_to_profile_idx_; } + const std::unordered_map& extra_metric_to_profile_idx() const { + return extra_metric_to_profile_idx_; + } + private: std::unordered_map instruction_to_profile_idx_; std::unordered_map computation_to_profile_idx_; + std::unordered_map extra_metric_to_profile_idx_; }; // Create an instance of `HloProfilePrinterData`. std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, - const HloCostAnalysis& cost_analysis); + const HloCostAnalysis& cost_analysis, const string& entry_computation_name); // Describes how much time each HLO operation took. // @@ -113,6 +130,12 @@ class HloExecutionProfile { total_cycles_executed; } + // Record extra metric. + void set_extra_metrics(const string& metric, uint64 value) { + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(metric)] = + value; + } + // Returns a version of the execution profile suitable for performance // debugging; e.g. emits cycle counts, execution time at the nominal device // frequency, and the effective throughput given the provided cost_analysis diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 460ae2b5eca786..5be9dba3aa49d6 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -54,7 +54,8 @@ TEST_F(HloExecutionProfileTest, Basic) { HloCostAnalysis cost_analysis(shape_size_function); HloProfileIndexMap profile_index_map(*hlo_module); std::unique_ptr profile_printer = - CreateHloProfilePrinterData(profile_index_map, cost_analysis); + CreateHloProfilePrinterData(profile_index_map, cost_analysis, + hlo_module->entry_computation()->name()); HloExecutionProfile execution_profile(profile_printer.get(), &profile_index_map); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f6ed86b41650fd..ada536770ed47b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2639,46 +2639,6 @@ Status HloInstruction::Accept( return this->Accept(&visitor); } -Status HloInstruction::AcceptOrdered( - DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; - TF_RET_CHECK(OrderIsTopologicalSort(order)); - - // Compute the predecessors of this instruction. - std::unordered_set predecessors; - TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) { - predecessors.insert(instruction); - return Status::OK(); - })); - - for (auto* const_instruction : order) { - if (!ContainsKey(predecessors, const_instruction)) { - // Instruction is not a predecessors of 'this'. - continue; - } - - // The visitor can mark instructions as visited to skip particular - // instructions. - if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO %" << const_instruction->name() - << " as it was already visited."; - continue; - } - - // TODO(b/78350259): Eliminate const laundering. - HloInstruction* instruction = - const_cast(const_instruction); - - TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO %" << instruction->name(); - TF_RETURN_IF_ERROR(instruction->Visit(visitor)); - visitor->SetVisited(*instruction); - TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); - } - - return visitor->FinishVisit(this); -} - const Shape& HloInstruction::shape() const { return shape_; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 15a4da8dbe0053..c6a938383ce708 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -954,16 +954,6 @@ class HloInstruction { Status Accept( const std::function& visitor_func) const; - // Visits all instructions rooted at this instruction using the given visitor - // in the given order. 'order' must contain at least the set of instructions - // rooted at this node (ie, those accessible from a DFS traversal from this - // instruction). Instructions contained in 'order' which are not in the set of - // instructions rooted at this node are ignored. 'order' must also be a valid - // topological sort of these instructions (defs appear before uses) though - // need not be a DFS post-order. - Status AcceptOrdered(DfsHloVisitor* visitor, - const std::vector& order); - // Visit this instruction and only this instruction with the given visitor. template Status Visit(DfsHloVisitorBase* visitor); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 5f06dc093248e1..bf4daf2be47ed0 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1144,6 +1144,9 @@ class HloPadInstruction : public HloInstruction { const PaddingConfig& padding_config); // Returns the padding configuration for a pad node. const PaddingConfig& padding_config() const { return padding_config_; } + // Returns the padding value. + const HloInstruction* padding_value() const { return operand(1); } + HloInstruction* mutable_padding_value() { return mutable_operand(1); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 5cee865b7ad34e..234fcd266aa09e 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -605,6 +605,23 @@ StatusOr HloMemoryScheduler::Run(HloModule* module) { return true; } +StatusOr HloTrivialScheduler::Run(HloModule* module) { + HloSchedule schedule(module); + for (HloComputation* computation : module->MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + HloInstructionSequence& computation_sequence = + schedule.GetOrCreateSequence(computation); + TF_RETURN_IF_ERROR(computation->Accept( + [&computation_sequence](HloInstruction* instruction) { + computation_sequence.push_back(instruction); + return Status::OK(); + })); + } + } + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + StatusOr HloDescheduler::Run(HloModule* module) { bool changed = module->has_schedule(); module->clear_schedule(); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index a4c1d3db8170a1..cca5dc49398981 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -108,6 +108,15 @@ class HloMemoryScheduler : public HloModulePass { MemorySchedulerAlgorithm algorithm_; }; +// A pass which produces a naive, but correct schedule. The schedule is produced +// using a DFS traversal of the graph with no attempt to minimize memory use. +class HloTrivialScheduler : public HloModulePass { + public: + absl::string_view name() const override { return "hlo-trivial-scheduler"; } + + StatusOr Run(HloModule* module) override; +}; + // A trivial pass which clears the schedule currently set on the // HloModule. After this pass runs HloModudle::has_schedule will return false. class HloDescheduler : public HloModulePass { diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 214119fba881c4..2f15997fc175c4 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -309,5 +309,40 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { .ValueOrDie()); } +TEST_F(HloSchedulingTest, TrivialScheduler) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + EXPECT_FALSE(module->has_schedule()); + TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + + // Verify that a clone of the module also has a schedule. + std::unique_ptr clone = module->Clone(); + ASSERT_TRUE(clone->has_schedule()); + TF_ASSERT_OK(clone->schedule().Verify()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index ab819b7031858e..6a838b7eb969d5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -41,18 +41,6 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} -StatusOr HloModule::LaunderConstInstructionFromModule( - const HloInstruction* hlo) { - if (hlo == nullptr) { - return nullptr; - } - - TF_RET_CHECK(hlo->GetModule() == this); - - // TODO(b/78350259): Eliminate const laundering. - return const_cast(hlo); -} - Status HloModule::set_schedule(HloSchedule schedule) { TF_RET_CHECK(schedule.module() == this); TF_RETURN_IF_ERROR(schedule.Verify()); @@ -576,6 +564,22 @@ std::unique_ptr HloModule::Clone(const HloModuleConfig& config, HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); module->AddEntryComputation(std::move(cloned_computation)); + + if (has_schedule() && schedule().Verify().ok()) { + HloSchedule clone_schedule(module.get()); + for (HloComputation* computation : computations()) { + if (schedule().is_computation_scheduled(computation)) { + HloInstructionSequence& clone_sequence = + clone_schedule.GetOrCreateSequence( + context.GetComputation(computation)); + for (const HloInstruction* instruction : + schedule().sequence(computation).instructions()) { + clone_sequence.push_back(context.GetInstruction(instruction)); + } + } + } + TF_CHECK_OK(module->set_schedule(std::move(clone_schedule))); + } return module; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 5dc795fabec5d8..8a1f999e3ab076 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -236,25 +236,6 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } - // Returns a non-const version of the passed-in const HloInstruction*. This is - // safe on the argument that if you have a non-const module, then you can - // access all instructions in the module as non-const. - // - // Returns an error if the passed-in instruction is not from this module, - // except that it is allowed to pass in a null pointer. - // - // TODO(b/78350259): Eliminate const laundering. The argument above is not - // reliable since at any time someone could add or discover a way for a - // non-const module to transitively contain a const HloInstruction. The - // reliable way to do this would be to create a const laundering map from a - // module, mapping each encountered HloInstruction to its non-const version - // and then look up each instruction in need of laundering in that map, but - // this is much more expensive and complicated. This returns a Status instead - // of doing a CHECK-failure in part to make it strongly apparent that this is - // something that can fail. - StatusOr LaunderConstInstructionFromModule( - const HloInstruction* hlo); - // Sets the schedule of the module to the given schedule. Status set_schedule(HloSchedule schedule); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 23d41d91d6969d..1c93641a588811 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -334,7 +334,7 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. for (auto* computation : module->MakeNonfusionComputations()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, HloReachabilityMap::Build(computation)); } } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 66313492eb2dd1..4dbe44769a2212 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index aec4902e9b6f5b..450660b94b783b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -418,6 +418,18 @@ std::pair* HloParser::FindInstruction( } return create_missing_instruction_(name, *shape); } + + if (instr != nullptr && shape.has_value() && + !ShapeUtil::Compatible(instr->first->shape(), shape.value())) { + Error( + lexer_.GetLoc(), + StrCat("The declared operand shape ", + ShapeUtil::HumanStringWithLayout(shape.value()), + " is not compatible with the shape of the operand instruction ", + ShapeUtil::HumanStringWithLayout(instr->first->shape()), ".")); + return nullptr; + } + return instr; } @@ -1794,6 +1806,10 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value, case U64: return SetValueInLiteralHelper(value, linear_index, literal); + case PRED: + // Bool type literals with rank >= 1 are printed in 0s and 1s. + return SetValueInLiteralHelper(static_cast(value), + linear_index, literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); @@ -2048,14 +2064,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - // TODO(congliu): bool type literals with rank >= 1 are actually - // printed in a compact form instead of "true" or "false". Fix that. if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, linear_index++, literal)) { return false; } lexer_.Lex(); - } else if (primitive_util::IsIntegralType(shape.element_type())) { + } else if (primitive_util::IsIntegralType(shape.element_type()) || + shape.element_type() == PRED) { LocTy loc = lexer_.GetLoc(); tensorflow::int64 value; if (!ParseInt64(&value)) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 19f84d8bd28371..c59bdc0a0b372d 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -75,6 +75,18 @@ ENTRY %constant_pred () -> pred[] { )" }, +// pred array constant +{ +"ConstantPredArray", +R"(HloModule module + +ENTRY %constant_pred_array () -> pred[2,3] { + ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) +} + +)" +}, + // s32 constant { "ConstantS32", @@ -1138,6 +1150,25 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } +)" +}, +// cross-replica-sum with all-reduce-id +{ +"CrossReplicaSumAllReduce", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add +} + )" }, // all-to-all @@ -2161,6 +2192,22 @@ ENTRY entry { ParseHloString(text)); } +TEST_F(HloParserTest, ShapeMismatchInOperand) { + const string text = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { + %p = f32[2,2] parameter(0) + %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) +} +)"; + + ExpectHasSubstr(ParseHloString(text).status().error_message(), + "The declared operand shape f32[2,5]{1,0} is not compatible" + " with the shape of the operand instruction f32[2,2]{1,0}."); +} + // custom call incompatible shape. } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index dcc22793015147..5eb707a957e49d 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" namespace xla { @@ -25,6 +26,11 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, string result; + for (const auto& item : hlo_profile_printer_data.extra_metrics()) { + absl::StrAppend(&result, "Extra metric ", item.first, ": ", + counters[item.second], "\n"); + } + for (const HloComputationInfo& computation_info : hlo_profile_printer_data.computation_infos()) { const auto& instruction_infos = computation_info.instruction_infos(); @@ -41,8 +47,9 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, // Once we start using this in AOT for real, we will probably need a more // minimal version of HumanReadableProfileBuilder. HumanReadableProfileBuilder builder( - computation_info.name(), counters[computation_info.profile_index()], - clock_rate_ghz); + computation_info.name(), + hlo_profile_printer_data.entry_computation() == computation_info.name(), + counters[computation_info.profile_index()], clock_rate_ghz); for (const auto& instruction_info : instruction_infos) { builder.AddOp( diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto index 9f22b733fe1d67..ee66c86ffcb4fb 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto +++ b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto @@ -57,4 +57,10 @@ message HloProfilePrinterData { // The size of the profile counters array we will pretty-print. int64 profile_counters_size = 2; + + // Maps extra metric name to the index into the profile counters array. + map extra_metrics = 3; + + // Name of the entry computation. + string entry_computation = 4; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 961930f0a888e9..7e73cf5889c37b 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { @@ -71,4 +73,70 @@ bool HloReachabilityMap::IsConnected(const HloInstruction* a, return IsReachable(a, b) || IsReachable(b, a); } +std::unique_ptr HloReachabilityMap::Build( + const HloComputation* computation) { + const auto& all = computation->MakeInstructionPostOrder(); + auto result = absl::make_unique(all); + auto channel_dependency_map = computation->ComputeChannelDependencies(); + + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + + result->FastSetReachabilityToUnion(inputs, hlo); + } + return result; +} + +void HloReachabilityMap::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction) { + std::queue worklist; + worklist.push(instruction); + + std::vector inputs; + + while (!worklist.empty()) { + const HloInstruction* item = worklist.front(); + worklist.pop(); + + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), item->control_predecessors().begin(), + item->control_predecessors().end()); + + if (SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 5a5f01f8fd647c..2c965f58bfadfb 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -32,11 +33,11 @@ class HloInstruction; // A class for representing reachability between HloInstructions. // -// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix -// and it is up to the user of the class to set the adjacency matrix such that -// it represents reachability, i.e. such that it is transitive. That the graph -// be transitive is thus not an invariant of this class, but it is required for -// the name of the class and its methods to make sense. +// It has an adjacency matrix and it is up to the user of the class to set the +// adjacency matrix such that it represents reachability, i.e. such that it is +// transitive. That the graph be transitive is thus not an invariant of this +// class, but it is required for the name of the class and its methods to make +// sense. class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given @@ -44,6 +45,15 @@ class HloReachabilityMap { explicit HloReachabilityMap( absl::Span instructions); + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + static std::unique_ptr Build( + const HloComputation* computation); + // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true @@ -70,6 +80,10 @@ class HloReachabilityMap { // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. + void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); + // Returns true if "b" is reachable from "a" // // Note that this function only correctly answers queries about reachability @@ -82,6 +96,9 @@ class HloReachabilityMap { // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + // Checks if an instruction is in the Reachability map. + bool IsPresent(const HloInstruction* a) const { return indices_.contains(a); } + private: // A bit-vector implementation specialized for this use case which provides a // fast bitwise OR operation not available in tensorflow::gtl::BitMap. diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index d9848cee0bfa90..21265d9f222527 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -81,6 +81,130 @@ TEST_F(HloReachabilityTest, Reachability) { EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); } +TEST_F(HloReachabilityTest, NonTrivialReachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp)); + + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = HloReachabilityMap::Build(computation); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + reachability->UpdateReachabilityThroughInstruction(exp); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + reachability->UpdateReachabilityThroughInstruction(negate); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + +TEST_F(HloReachabilityTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = HloReachabilityMap::Build(computation); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 0778ff52174ef8..a5780b7551a43f 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -264,7 +264,10 @@ Status HloSchedule::Verify() const { } TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); + computation->instruction_count()) + << "Schedule for computation " << computation->name() << " has " + << instruction_position.size() << " instructions, expected " + << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(instruction_position.count(instruction) == 1) << "Instruction " << instruction->name() << " is not in schedule"; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index af78326b81f8be..a2a6fb7c77e73d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -65,7 +66,9 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) { return VerifyNotSparse(hlo->shape()); } -static Status CheckOperandCount(const HloInstruction* hlo, int expected) { +namespace { + +Status CheckOperandCount(const HloInstruction* hlo, int expected) { if (hlo->operand_count() != expected) { return InternalError("Expected %d operands for %s instruction: %s", expected, HloOpcodeString(hlo->opcode()), @@ -74,6 +77,19 @@ static Status CheckOperandCount(const HloInstruction* hlo, int expected) { return Status::OK(); } +Status CheckParameterCount(const HloInstruction* calling_instruction, + const HloComputation* computation, int expected) { + if (computation->num_parameters() != expected) { + return InternalError( + "Expected computation %s called from %s to have %d parameters, has %d", + computation->name(), calling_instruction->name(), expected, + computation->num_parameters()); + } + return Status::OK(); +} + +} // namespace + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -441,6 +457,8 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { } Status ShapeVerifier::HandleCall(HloInstruction* call) { + TF_RETURN_IF_ERROR( + CheckParameterCount(call, call->to_apply(), call->operand_count())); for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); } @@ -540,6 +558,10 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); + TF_RETURN_IF_ERROR( + CheckParameterCount(xla_while, xla_while->while_body(), 1)); + TF_RETURN_IF_ERROR( + CheckParameterCount(xla_while, xla_while->while_condition(), 1)); TF_RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); TF_RETURN_IF_ERROR( @@ -560,6 +582,10 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); + TF_RETURN_IF_ERROR( + CheckParameterCount(conditional, conditional->true_computation(), 1)); + TF_RETURN_IF_ERROR( + CheckParameterCount(conditional, conditional->false_computation(), 1)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( conditional, 1, conditional->true_computation(), 0)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( @@ -1306,6 +1332,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleCrossReplicaSum(HloInstruction* crs) override { + if (crs->all_reduce_id().has_value()) { + TF_RET_CHECK(crs->all_reduce_id().value() > 0) + << "All reduce id must be greater than 0 for " + << crs->ToShortString(); + } + return Status::OK(); + } + Status Preprocess(HloInstruction* instruction) override { auto previous = instructions_by_name_.find(instruction->name()); TF_RET_CHECK(previous == instructions_by_name_.end()) diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index e76b93107c923b..e103222b55facc 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -112,8 +112,9 @@ string HumanReadableProfileBuilder::ToString() const { VLOG(1) << "Total floating point ops: " << total_flops; - print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}, + print_op({is_entry_computation_ ? "[total] [entry]" : "[total]", "[total]", + /*category=*/"", total_cycles_, total_flops, total_transcendentals, + total_bytes, optimal_seconds_sum}, /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 925111fa1f1e48..d4e5cbbe27418d 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -30,9 +30,11 @@ namespace xla { class HumanReadableProfileBuilder { public: explicit HumanReadableProfileBuilder(absl::string_view computation_name, + bool is_entry_computation, int64 total_cycles, double clock_rate_ghz) : computation_name_(computation_name), + is_entry_computation_(is_entry_computation), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -75,6 +77,7 @@ class HumanReadableProfileBuilder { } string computation_name_; + bool is_entry_computation_; int64 total_cycles_; double clock_rate_ghz_; std::vector op_infos_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 69a4c160ee5c45..426c1256080ac9 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -26,7 +26,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -437,8 +439,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { } // namespace std::unique_ptr InstructionFusion::GetFusionQueue( - HloComputation* computation, - const std::function& skip_producer) { + HloComputation* computation) { return absl::make_unique(computation); } @@ -451,14 +452,11 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); - auto fusion_queue = - GetFusionQueue(computation_, [&](HloInstruction* producer) { - return do_not_duplicate.count(producer) > 0; - }); + auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -489,9 +487,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { HloInstruction* fusion_instruction; // Try "regular" fusion if the operand may be duplicated. Otherwise, // perform multi-output fusion, unless this creates a cycle. - // TODO(tjoerg): Consider making multi-output fusion the default. - if (ShouldFuse(instruction, i) && - do_not_duplicate.count(operand) == 0) { + if (do_not_duplicate.count(operand) == 0 && + ShouldFuse(instruction, i)) { fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && @@ -565,15 +562,19 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return absl::c_any_of( - consumer->operands(), [&](const HloInstruction* consumer_operand) { - // The fusion algorithm traverses the HLO graph in reverse post order. - // Thus `cosumers` is visited before its operands (including - // `producer`). Therefore, consumer operands cannot have been fused yet. - // It is thus safe to use the pre-computed reachability map. - return consumer_operand != producer && - reachability_->IsReachable(producer, consumer_operand); - }); + auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { + // A consumer operand may have been multii-output fused into a parallel + // consumer and thus be missing from the oridinal reachability map. + if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { + reachability_ = HloReachabilityMap::Build(consumer->parent()); + } + return reachability_->IsReachable(a, b); + }; + return absl::c_any_of(consumer->operands(), + [&](const HloInstruction* consumer_operand) { + return consumer_operand != producer && + is_reachable(producer, consumer_operand); + }); } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f14c6675208c72..198bd7fce5f392 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -54,8 +55,7 @@ class InstructionFusion : public HloModulePass { // fused. The default implementation processes consumers in reverse post // order. virtual std::unique_ptr GetFusionQueue( - HloComputation* computation, - const std::function& skip_producer); + HloComputation* computation); // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. @@ -111,6 +111,10 @@ class InstructionFusion : public HloModulePass { return is_expensive_(instruction); } + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; HloModule* module_; @@ -145,10 +149,6 @@ class InstructionFusion : public HloModulePass { // duplicated. std::function is_expensive_; - // Whether multi-output fusion would introduce a cycle into the HLO graph. - bool MultiOutputFusionCreatesCycle(HloInstruction* producer, - HloInstruction* consumer); - // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 232d1dc0879cd6..6b03394669858e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -449,7 +449,6 @@ Status LayoutAssignment::AddMandatoryConstraints( // instruction. // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. - DCHECK(!LayoutUtil::IsPadded(instruction->shape())); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(instruction->shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kOutfeed) { @@ -989,10 +988,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape()) && @@ -1251,12 +1248,20 @@ Status LayoutAssignment::PropagateOperandConstraint( operand_constraint.operand(), constraints)); // For array-shaped operands and user instructions try to pick a minimum cost - // layout. For example, if the operand of a elementwise instruction is - // constained to a certain layout we want the output of the instruction to + // layout. For example, if the operand of an elementwise instruction is + // constrained to a certain layout we want the output of the instruction to // have the same layout. + // + // If the user is not array-shaped, we still want to propagate the layout + // to siblings if the instruction can't change layout. This is to represent + // the information that non-layout-changing instructions should have the same + // layout for the operands with the same ranks. const HloInstruction* operand = operand_constraint.operand(); const HloInstruction* user = operand_constraint.instruction(); - if (!ShapeUtil::IsArray(operand->shape()) || + if (!ShapeUtil::IsArray(operand->shape())) { + return Status::OK(); + } + if (instruction_can_change_layout_func_(user) && !ShapeUtil::IsArray(user->shape())) { return Status::OK(); } @@ -1267,52 +1272,183 @@ Status LayoutAssignment::PropagateOperandConstraint( operand_constraint.operand_no())) { return Status::OK(); } - TF_ASSIGN_OR_RETURN( - const LogicalBuffer* buffer, - constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{})); - if (constraints->BufferLayout(*buffer) == nullptr) { - std::unique_ptr layout = ChooseOutputLayoutFromOperandLayout( - operand_constraint.shape_layout().layout(), user, - operand_constraint.operand_no()); - if (layout != nullptr) { - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(*layout, *buffer, /*mandatory=*/false)); + int64 operand_rank = ShapeUtil::Rank(operand->shape()); + if (operand_rank <= 1) { + return Status::OK(); + } + + // Propagate layouts between operands of the same instruction. This is a + // constraint on non-layout-changing instructions. + if (!instruction_can_change_layout_func_(user)) { + // Make sure all siblings have the same layout as the operand. + for (int64 operand_no = 0; operand_no < user->operand_count(); + ++operand_no) { + if (user->operand(operand_no) == operand) { + continue; + } + const HloInstruction* sibling = user->operand(operand_no); + const int64 sibling_rank = ShapeUtil::Rank(sibling->shape()); + if (sibling_rank <= 1) { + continue; + } + if (operand_rank != sibling_rank) { + continue; + } + const OperandLayoutConstraint* constraint = + constraints->GetOperandLayoutConstraint(user, operand_no); + if (constraint != nullptr) { + // Due to the DFS of the propagation we can end up here when operand_no + // has a layout set that hasn't been propagated yet (is still on the + // stack of layouts to propagate). + // We can continue here and leave the operands with different layouts, + // as we will either: + // - overwrite the current operand when the DFS gets back to propagating + // operand(operand_no) to its siblings + // - overwrite operand(operand_no)'s layout with a mandatory layout if + // we continue to propagate our layout to the result, and then + // backwards into all operands (if the result is an array of rank > 1) + continue; + } + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + operand_constraint.shape_layout().layout(), user, operand_no, + /*mandatory=*/false)); } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + user->shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsTuple(subshape)) { + return Status::OK(); + } + if (ShapeUtil::Rank(subshape) <= 1) { + return Status::OK(); + } + + // Assign the right layout to input fusion of higher rank reduce + // operations. + if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) { + return Status::OK(); + } + // TODO(b/67641796): Are there cases except fusion that use this code + // path? + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt( + user, shape_index)); + // Make sure the output has the same layout as the operand. + const BufferLayoutConstraint* constraint = + constraints->GetBufferLayoutConstraint(*buffer); + // If we already have a constraint for the buffer it was assigned but + // hasn't propagated yet. This can happen with diamond-shaped graphs + // where one path is first evaluated in depth-first order (we're here) + // and the other path is propagated later. We don't set the layout + // here as it will always be overwritten later. + if (constraint == nullptr) { + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + operand_constraint.shape_layout().layout(), *buffer, + /*mandatory=*/false)); + } + return Status::OK(); + })); + return Status::OK(); } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsTuple(subshape)) { + return Status::OK(); + } + if (ShapeUtil::Rank(subshape) <= 1) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt(user, + shape_index)); + if (constraints->BufferLayout(*buffer) == nullptr || + !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) { + std::unique_ptr layout = ChooseOutputLayoutFromOperandLayout( + operand_constraint.shape_layout().layout(), user, + operand_constraint.operand_no()); + if (layout != nullptr) { + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + *layout, *buffer, + /*mandatory=*/user->opcode() == HloOpcode::kReduce, + /*dfs=*/false)); + } + } + return Status::OK(); + })); return Status::OK(); } -Status LayoutAssignment::PropagateBufferConstraint( +Status LayoutAssignment::PropagateBufferConstraintToOperands( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) { - // Only propagate array layouts. + VLOG(5) << "PropagateBufferConstraintToOperands: " + << buffer_constraint.ToString(); const LogicalBuffer& buffer = buffer_constraint.buffer(); - if (!buffer.IsArray()) { + + const HloInstruction* instruction = buffer.instruction(); + if (IsAtMostRank1(instruction->shape())) { return Status::OK(); } - // If this buffer is the result of an array-shaped op (as opposed to an array - // element in a tuple) try to propagate the layout to its operands. - if (buffer.IsTopLevel()) { - const HloInstruction* instruction = buffer.instruction(); - // Propagate the def-constraint on an instruction to the use-constraints on - // its operands (use-def propagation). - for (int64 operand_no = 0; operand_no < instruction->operand_count(); - ++operand_no) { - if (constraints->OperandLayout(instruction, operand_no) == nullptr && - ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const HloInstruction* operand = instruction->operand(operand_no); + if (IsAtMostRank1(operand->shape())) { + continue; + } + if (!instruction_can_change_layout_func_(instruction)) { + // Copy the layout to the operand. + if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == + LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + buffer_constraint.layout(), instruction, operand_no, + /*mandatory=*/true)); + } + } else { + if (!buffer.IsTopLevel() || + !ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; // Don't touch buffers that are internal to a tuple. + } + VLOG(6) << "Propagating constraint to operand " << operand_no << " of " + << instruction->ToShortString(); + // Assign a layout if there is no constraint already. + const OperandLayoutConstraint* constraint = + constraints->GetOperandLayoutConstraint(instruction, operand_no); + if (constraint == nullptr || !constraint->mandatory()) { std::unique_ptr operand_layout = ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { + // Do not propagate operand constraints of transposes and reshapes, it + // tends to create really bad layouts. TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( - *operand_layout, instruction, operand_no, /*mandatory=*/true)); + *operand_layout, instruction, operand_no, /*mandatory=*/false, + /*dfs=*/false)); } + } else { + VLOG(6) << "Operand already has a constraint " + << constraint->ToString(); } } } - return PropagateBufferConstraintToUses(buffer_constraint, constraints); + return Status::OK(); +} + +Status LayoutAssignment::PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + // Only propagate array layouts. + const LogicalBuffer& buffer = buffer_constraint.buffer(); + if (!buffer.IsArray()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + PropagateBufferConstraintToUses(buffer_constraint, constraints)); + return PropagateBufferConstraintToOperands(buffer_constraint, constraints); } Status LayoutAssignment::PropagateBufferConstraintToUses( @@ -1340,12 +1476,12 @@ Status LayoutAssignment::PropagateBufferConstraintToUses( } Status LayoutAssignment::PropagateResultConstraint( - const ResultLayoutConstraint& result_constraint, + const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { // Propagate the use constraint of the root instruction up to the logical // buffers which make up the result. return PropagateUseConstraintToDefs( - result_constraint.shape_layout(), + layout_constraint.shape_layout(), constraints->computation()->root_instruction(), constraints); } @@ -1960,6 +2096,16 @@ bool LayoutAssignment::InstructionCanChangeLayout( } } +/* static */ +bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { + if (ShapeUtil::IsArray(shape)) { + return ShapeUtil::Rank(shape) <= 1; + } + return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { + return IsAtMostRank1(subshape); + }); +} + Status LayoutAssignment::Init() { computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index cb56f4cd19ded0..3b081de3c7826c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -315,6 +315,10 @@ class LayoutAssignment : public HloModulePass { // rank as the output to have the same layout as the output. static bool InstructionCanChangeLayout(const HloInstruction* instruction); + // In case of an array shape returns true iff it is at most rank 1. In case of + // a tuple shape returns true iff all leaf shapes are at most rank 1. + static bool IsAtMostRank1(const Shape& shape); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. operands and users) in order to minimize @@ -362,7 +366,7 @@ class LayoutAssignment : public HloModulePass { // `user` that minimizes its cost on that operand. Returns null if it can't // decide the best layout. // Precondition: `user` and the operand are array-shaped. - std::unique_ptr ChooseOutputLayoutFromOperandLayout( + virtual std::unique_ptr ChooseOutputLayoutFromOperandLayout( const Layout& operand_layout, const HloInstruction* user, int64 operand_no); @@ -408,6 +412,10 @@ class LayoutAssignment : public HloModulePass { // required for correctness. Status PropagateConstraints(LayoutConstraints* constraints); + Status PropagateBufferConstraintToOperands( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints); + // Check that all layouts in the module have been set and satisfy all // necessary conditions. Status CheckLayouts(HloModule* module); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a831751fa96f8c..11c57682c11577 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -897,11 +897,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 ar.0 = f32[2,2] cross-replica-sum(gte), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -1291,5 +1291,59 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); } +Status AssignLayoutsToComputation( + HloModule* module, + ChannelLayoutConstraints* channel_constraints = nullptr) { + if (!module->entry_computation_layout().result_layout().LayoutIsSet()) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->SetToDefaultLayout(); + } + LayoutAssignment layout_assignment( + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, channel_constraints); + return layout_assignment.Run(module).status(); +} + +TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { + // Check that we handle a diamond-shaped graph correctly. + // transpose + // / \ + // add | + // \ / + // tuple + + auto b = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {12, 8}); + Shape bshape = ShapeUtil::MakeShape(F32, {8, 12}); + auto param0 = + b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input")); + auto param1 = + b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input")); + auto transpose = + b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0})); + auto add = b.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1)); + b.AddInstruction(HloInstruction::CreateTuple({add, transpose})); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(b.Build()); + Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0}); + Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1}); + *module->mutable_entry_computation_layout()->mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor})); + const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0}); + ForceParameterLayout(module.get(), 0, r2_dim0major); + ForceParameterLayout(module.get(), 1, r2_dim0major); + TF_ASSERT_OK(AssignLayoutsToComputation(module.get())); + + EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0)); + EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + + EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 850501a4b5c521..56a729bca8ec04 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index cc2e862f2eb9a4..4d7f36d9f8b565 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -130,7 +130,8 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, // // Emits a sequential loop if launch_dimensions is null. static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -160,7 +161,8 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); // Create element generators for update and start_indices. - FusedIrEmitter fused_emitter(fusion_operand_arrays, elemental_emitter); + FusedIrEmitter fused_emitter(std::move(operand_arrays_generator), + elemental_emitter); TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); ElementGenerator start_indices_generator = @@ -173,21 +175,24 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( } Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( - fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + fusion, std::move(operand_arrays_generator), fusion_output_array, + elemental_emitter, /*launch_dimensions=*/nullptr, b); } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( - fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, - &launch_dimensions, b); + fusion, std::move(operand_arrays_generator), fusion_output_array, + elemental_emitter, &launch_dimensions, b); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index fb3e4eb97cae06..7fe803d1f8da52 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -27,6 +27,9 @@ limitations under the License. namespace xla { namespace llvm_ir { +using GeneratorForOperandIrArrays = + std::function()>; + // Checks if we can emit code for the given DynamicUpdateSlice node that updates // its input in place. Returns true if the dynamic-update-slice's // array-to-be-updated and output share the same BufferAllocation::Slice. @@ -73,14 +76,16 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, // (sequential) code for a fusion node that does the dynamic-update-slice in // place. Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index b0a492c70b8ddc..38f2b5da23a7b9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -72,16 +72,17 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { } Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *b_->GetInsertBlock()->getModule(), initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, - /*Name=*/""); - llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( - global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); indexed_generators_[constant] = [=](const IrArray::Index& index) { + const Literal& literal = constant->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *b_->GetInsertBlock()->getModule(), initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, + /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, b_); }; @@ -105,7 +106,7 @@ Status FusedIrEmitter::HandleGetTupleElement( tuple_operand->name()); } tuple_ptr = - parameter_arrays_[tuple_operand->parameter_number()].GetBasePointer(); + GetBasePointerForFusedParameter(tuple_operand->parameter_number()); } // Lookup tuple element pointer. @@ -148,7 +149,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { "tiled_buffer"); } } - return parameter_arrays_[parameter->parameter_number()] + return GetIrArrayForFusedParameter(parameter->parameter_number()) .EmitReadArrayElement(index, b_); }; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 076c449c1e42eb..1b9c61f6700e2a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -54,10 +55,13 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using IndexedGenerator = llvm_ir::ElementGenerator; using NonIndexedGenerator = std::function()>; + using GeneratorForOperandIrArrays = + std::function()>; - FusedIrEmitter(absl::Span parameter_arrays, + FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator, ElementalIrEmitter* elemental_emitter) - : parameter_arrays_(parameter_arrays), + : operand_arrays_(), + operand_arrays_generator_(std::move(operand_arrays_generator)), tiled_parameter_info_(nullptr), elemental_emitter_(elemental_emitter), b_(elemental_emitter->b()), @@ -86,9 +90,25 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { tiled_parameter_info_ = info; } + protected: + // Returns the IrArrays for the fusion instruction operands. + llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) { + if (!operand_arrays_.has_value()) { + operand_arrays_ = operand_arrays_generator_(); + } + return operand_arrays_.value()[parameter_number]; + } + + llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) { + return GetIrArrayForFusedParameter(parameter_number).GetBasePointer(); + } + private: - // Arrays of parameters of fusion instruction - absl::Span parameter_arrays_; + // IrArrays for the fusion instruction operands, whose base addresses are the + // base address of the corresponding parameters in the fused computation. + absl::optional> operand_arrays_; + GeneratorForOperandIrArrays operand_arrays_generator_; + const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 2ca527bc4cb8f6..6088fa4df66a6e 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -257,7 +258,7 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, } void MultiOutputFusion::RecomputeReachability() { - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); } void MultiOutputFusion::UpdateReachability( diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 9508ab2ed1d38e..1c7583ece720f9 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index f952e64af2b675..49f0b8f8b72001 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -95,7 +95,13 @@ class TransferManager { // but need not have the same layout. // // This operation is performed asynchronously on the given stream. It returns - // once the transfer is enqueued. + // once the transfer is enqueued, and may return before the transfer has + // completed. + // + // The caller may free the data structures 'literal' and 'device_buffer' + // immediately after this function returns, however their constituent buffers + // on both host and device must remain valid until the enqueued transfer has + // completed on 'stream'. virtual Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, const ShapedBuffer& device_buffer) = 0; diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 0e7667de832c54..d17b86fab5b14d 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -114,7 +114,7 @@ HloModule ModuleWithWhile body { p_b = (f32[2],(f32[2],f32[2])) parameter(0) - p_b.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=0 + p_b.0 = f32[2] get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=0 p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1 p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0 diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d7aa29d8ee9be1..17120e610cb26d 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -116,16 +116,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } - if (!absl::c_equal(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { - VLOG(3) - << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; - return false; - } - if (lhs.layout().padding_value() != rhs.layout().padding_value()) { - VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; - return false; - } } } @@ -818,17 +808,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - absl::Span padded_dimensions = - LayoutUtil::PaddedDimensions(shape); - if (!padded_dimensions.empty()) { - CHECK_EQ(Rank(shape), padded_dimensions.size()); - allocated_element_count = 1; - for (int64 dimension_size : padded_dimensions) { - allocated_element_count *= dimension_size; - } - } else { - allocated_element_count = ElementsIn(shape); - } + allocated_element_count = ElementsIn(shape); } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); @@ -946,12 +926,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return dense_shape_size; } - bool is_padded = shape_has_valid_layout && - LayoutUtil::IsDenseArray(shape) && - LayoutUtil::IsPadded(shape); absl::Span shape_max_dimensions = - is_padded ? LayoutUtil::PaddedDimensions(shape) - : AsInt64Slice(shape.dimensions()); + AsInt64Slice(shape.dimensions()); for (int64 dim : shape_max_dimensions) { dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim); if (dense_shape_size < 0) { @@ -1193,13 +1169,6 @@ Status ForEachMutableSubshapeHelper( permutation, AsInt64Slice(shape.layout().minor_to_major()))) { new_layout->add_minor_to_major(index); } - if (shape.layout().padded_dimensions_size() > 0) { - new_layout->clear_padded_dimensions(); - for (auto dim : - Permute(permutation, shape.layout().padded_dimensions())) { - new_layout->add_padded_dimensions(dim); - } - } // The permutation accepted by TransposeIsBitcast is the inverse of the // permutation here. CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) @@ -1302,11 +1271,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return false; } - // Padding is not handled. - if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) { - return false; - } - // Check the reshape permutes the positions of each dimension in the // minor-to-major order. positions[i]=k means dimension `i` is k-th minor. // input_positions = apply(dimension_mapping, output_positions) @@ -1338,11 +1302,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return false; } - // Padding is not handled. - if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) { - return false; - } - CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape)); if (ElementsIn(input_shape) == 0) { return true; diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index c622ecdca1fd66..0c647369a37e70 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -345,26 +345,6 @@ TEST(ShapeUtilTest, OpaqueVsArray) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); } -TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { - Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); - shape1.mutable_layout()->add_padded_dimensions(10); - - Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30}); - shape2.mutable_layout()->add_padded_dimensions(11); - - EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); -} - -TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) { - Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); - shape1.mutable_layout()->set_padding_value(ZERO_PAD); - - Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30}); - shape2.mutable_layout()->set_padding_value(LOWEST_PAD); - - EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); -} - TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) { Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {}); ASSERT_TRUE(scalar_default_layout.has_layout()) @@ -395,16 +375,6 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } -TEST(ShapeUtilTest, ByteSizeOfWithPadding) { - EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32)); - Shape shape = ShapeUtil::MakeShape(F32, {10, 20}); - EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape)); - - shape.mutable_layout()->add_padded_dimensions(15); - shape.mutable_layout()->add_padded_dimensions(21); - EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); -} - TEST(ShapeUtilTest, NilShape) { EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5c6183984ff5d0..d395c9a4ceecfb 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1177,6 +1177,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c131bfd6a6e6d8..2180b22cb3bc2e 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2478,8 +2478,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { - { 00 }, - { 01 } + { 0, 0 }, + { 0, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2492,8 +2492,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 1100 }, - { 0001 } + { 1, 1, 0, 0 }, + { 0, 0, 0, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2506,8 +2506,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 0100 }, - { 0000 } + { 0, 1, 0, 0 }, + { 0, 0, 0, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2520,8 +2520,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 1011 }, - { 1111 } + { 1, 0, 1, 1 }, + { 1, 1, 1, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2534,8 +2534,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 0011 }, - { 1110 } + { 0, 0, 1, 1 }, + { 1, 1, 1, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2744,12 +2744,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); const string expected = R"(pred[2,3,2] { -{ { 01 }, - { 00 }, - { 00 } }, -{ { 01 }, - { 10 }, - { 01 } } +{ { 0, 1 }, + { 0, 0 }, + { 0, 0 } }, +{ { 0, 1 }, + { 1, 0 }, + { 0, 1 } } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 58539e6b061b0c..774eb8d2a85914 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -87,8 +87,8 @@ TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); ConstantR2(&builder, {{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { - { 011 }, - { 100 } + { 0, 1, 1 }, + { 1, 0, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 83997cdac21c43..18c99490a38792 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -980,5 +981,25 @@ XLA_TEST_F(ReduceTest, OrReduceU64) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ReduceTest, R0ReduceInDisguise) { + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); + constexpr int element_count = 127; + const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count, 1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + Array2D input_data(element_count, 1); + input_data.FillRandom(3.0f); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + std::unique_ptr input_global_data = + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + + float expected = absl::c_accumulate(input_data, 0.0f); + ComputeAndCompareR1(&builder, {expected}, {input_global_data.get()}, + ErrorSpec(0.001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index a6e70eb6ca25ff..376559500efad6 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -99,8 +99,8 @@ Status ParseOneProfileOutputLine( // %dot33 = f32[256,256]{1,0} dot(...) // ^^^ - string match_opcode = - expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; + string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" + : "(\\[total\\])( \\[entry\\])?"; string regexp_pattern = absl::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, @@ -125,6 +125,10 @@ Status ParseOneProfileOutputLine( return Status::OK(); } +bool IsExtraMetricProfileOutputLine(const string& line) { + return RE2::FullMatch(line, "Extra metric \\S+: \\d+"); +} + // Returns void so that we can ASSERT. void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, const XlaComputation& computation, @@ -210,14 +214,26 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { absl::flat_hash_map parsed_profile_lines; - TF_ASSERT_OK(ParseOneProfileOutputLine( - profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); + int line_no = 0; + + // Skip extra metrics. + while (IsExtraMetricProfileOutputLine(profile_output_lines[line_no])) { + line_no++; + } + + line_no++; // Skip 'Execution profile for ....' + + TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], + /*expect_hlo=*/false, + &parsed_profile_lines)); - TF_ASSERT_OK(ParseOneProfileOutputLine( - profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines)); + TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], + /*expect_hlo=*/true, + &parsed_profile_lines)); - TF_ASSERT_OK(ParseOneProfileOutputLine( - profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines)); + TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], + /*expect_hlo=*/true, + &parsed_profile_lines)); TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile, MaybeFind(parsed_profile_lines, "[total]")); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 73b3589dbf1234..b6bd919e2b26a1 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -78,28 +78,6 @@ enum PrimitiveType { // Next = 18 } -// Describes the value held inside padding elements. -enum PaddingValue { - INVALID_PAD = 0; - - // Zero padding must be 0-values that correspond to the shape's element type. - ZERO_PAD = 1; - - // One padding must be 1-values that correspond to the shape's element type. - ONE_PAD = 2; - - // "Lowest" padding must be the lowest values in the shape's element type, - // used as padding for operations like max-accumulation. - LOWEST_PAD = 3; - - // "Highest" padding must be the largest values in the shape's element type, - // used as padding for operations like min-accumulation. - HIGHEST_PAD = 4; - - // Unknown padding could be anything; e.g. floating NaNs! - UNKNOWN_PAD = 5; -} - // Describes the padding configuration for Pad operation. The padding amount on // both edges as well as between the elements are specified for each dimension. message PaddingConfig { @@ -123,8 +101,7 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { INVALID_FORMAT = 0; - // The default layout, with exactly one storage location per element (ignoring - // padding). + // The default layout, with exactly one storage location per element. DENSE = 1; // A sparsely encoded layout, providing only the index/value pairs of non-zero // elements. @@ -132,8 +109,7 @@ enum Format { } // A layout describes how the array is placed in (1D) memory space. This -// includes the minor-to-major ordering of dimensions within a shape, as well as -// any padding present in those dimensions. +// includes the minor-to-major ordering of dimensions within a shape. // // Clients must specify the layouts of input Literals to the // computation. Layouts specified in interior operations which take Shapes (for @@ -151,16 +127,11 @@ message Layout { // (slowest varying index). This field is required. repeated int64 minor_to_major = 1; - // The width to which the layout of each dimension is padded up to. If - // present, the size of the padded_dimensions must equal the rank of the - // shape. The padding appears at the end of a dimension, not at the - // beginning. This kind of padding, unlike padding in e.g. convolution, is not - // part of the shape. This field must be unset unless the format is DENSE. - repeated int64 padded_dimensions = 2; + reserved 2; + reserved "padded_dimensions"; - // Describes the values in the padding specified by padded_dimensions. This - // field must be unset unless the format is DENSE. - PaddingValue padding_value = 3; + reserved 3; + reserved "padding_value"; // The maximum number of elements that can be stored for SPARSE formats. This // can be used to determine the maximum size in bytes of arrays stored in diff --git a/tensorflow/compiler/xla/xlalogo.png b/tensorflow/compiler/xla/xlalogo.png deleted file mode 100644 index 7a0a295953d0c4..00000000000000 Binary files a/tensorflow/compiler/xla/xlalogo.png and /dev/null differ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 5d4819b0f1cb59..efa2ab1dad8df9 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -19,15 +19,17 @@ from __future__ import print_function import abc + +import six + from tensorflow.contrib.boosted_trees.python.ops import batch_ops_utils from tensorflow.python.ops import control_flow_ops +@six.add_metaclass(abc.ABCMeta) class BaseSplitHandler(object): """Abstract Base class defining split handlers interface.""" - __metaclass__ = abc.ABCMeta - def __init__(self, l1_regularization, l2_regularization, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f45010ec26ed25..1fffbb5f660c68 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -142,7 +142,7 @@ def __init__(self, name="StatsAccumulator/{}".format(self._name)) # Allocate both stats accumulator and quantile accumulator on the same # device so that we can build splits with fewer RPCs. - with ops.colocate_with(self._stats_accumulator.resource()): + with ops.colocate_with(self._stats_accumulator.resource_handle): self._quantile_accumulator = quantile_ops.QuantileAccumulator( init_stamp_token, epsilon=epsilon, @@ -268,8 +268,8 @@ def make_splits(self, stamp_token, next_stamp_token, class_id): handler = make_dense_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, @@ -447,8 +447,8 @@ def make_splits(self, stamp_token, next_stamp_token, class_id): handler = make_sparse_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 05ce0884ccfff5..356ae337685d58 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -34,7 +34,7 @@ def testSimpleAcculumator(self): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -62,7 +62,7 @@ def testMultidimensionalAcculumator(self): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2, 1], @@ -91,7 +91,7 @@ def testDropStaleUpdate(self): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -123,7 +123,7 @@ def testSerialize(self): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -133,7 +133,7 @@ def testSerialize(self): with ops.control_dependencies([op1]): (stamp_token, num_updates, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -164,7 +164,7 @@ def testDeserialize(self): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. op1 = accumulator.add( stamp_token=0, @@ -175,7 +175,7 @@ def testDeserialize(self): with ops.control_dependencies([op1]): deserialize = ( - accumulator.deserialize( + accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], @@ -223,7 +223,7 @@ def testSimpleAcculumator(self): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -261,7 +261,7 @@ def testMultidimensionalAcculumator(self): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -299,7 +299,7 @@ def testDropStaleUpdate(self): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -336,7 +336,7 @@ def testSerialize(self): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -349,7 +349,7 @@ def testSerialize(self): with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -386,7 +386,7 @@ def testDeserialize(self): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. op1 = accumulator.add( stamp_token=0, @@ -399,7 +399,7 @@ def testDeserialize(self): 0.08]]]) with ops.control_dependencies([op1]): - deserialize = accumulator.deserialize( + deserialize = accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 843420968ac6a6..4dc764f95713ab 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ import abc import collections +import six + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,11 +29,10 @@ from tensorflow.python.ops import array_ops +@six.add_metaclass(abc.ABCMeta) class ScheduledOp(object): """Represents a scheduled remote operation.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def batching_key(self): """Returns the key for batching operations.""" diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 25b2c9e2fd72bd..fca22c71a83459 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +import functools + # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader # pylint: enable=unused-import @@ -31,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -82,6 +85,44 @@ def restore(self, restored_tensors, unused_restored_shapes): tree_ensemble_config=restored_tensors[1]) +class TreeEnsembleVariable(tracking.TrackableResource): + """A Tree ensemble model.""" + + def __init__(self, stamp_token, tree_ensemble_config, name, container=None): + self._stamp_token = stamp_token + self._tree_ensemble_config = tree_ensemble_config + self._name = name + self._container = container + self._init_op = None + super(TreeEnsembleVariable, self).__init__() + + def create_resource(self): + return gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_ensemble_variable( + self.resource_handle, self._stamp_token, self._tree_ensemble_config) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_ensemble_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return { + "tree_ensemble_variable": + functools.partial( + TreeEnsembleVariableSavable, + tree_ensemble_handle=self.resource_handle, + create_op=self.initializer) + } + + def tree_ensemble_variable(stamp_token, tree_ensemble_config, name, @@ -99,12 +140,11 @@ def tree_ensemble_variable(stamp_token, A `Tensor` of type mutable `string`. The handle to the tree ensemble. """ with ops.name_scope(name, "TreeEnsembleVariable") as name: - resource_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op( - container, shared_name=name, name=name) - create_op = gen_model_ops.create_tree_ensemble_variable( - resource_handle, stamp_token, tree_ensemble_config) - is_initialized_op = gen_model_ops.tree_ensemble_is_initialized_op( - resource_handle) + tree_ensemble_var = TreeEnsembleVariable(stamp_token, tree_ensemble_config, + name, container) + resource_handle = tree_ensemble_var.resource_handle + create_op = tree_ensemble_var.initializer + is_initialized_op = tree_ensemble_var.is_initialized() # Adds the variable to the savable list. saveable = TreeEnsembleVariableSavable(resource_handle, create_op, resource_handle.name) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 19b6b3296db394..0c319cc9bd1f72 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,59 +33,20 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") -class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): - """A resource that allows distributed quantile computation.""" - - def __init__(self, - init_stamp_token, - epsilon, - num_quantiles, - max_elements=None, - name=None, - container=None, - generate_quantiles=False): - """Creates a QuantileAccumulator object. - - Args: - init_stamp_token: The initial value for the stamp token. - epsilon: Error bound on the quantile computation. - num_quantiles: Number of quantiles to produce from the final summary. - max_elements: Maximum number of elements added to the accumulator. - name: the name to save the accumulator under. - container: An optional `string`. Defaults to `""` - generate_quantiles: Generate quantiles instead of approximate boundaries. - If true, exactly `num_quantiles` will be produced in the final summary. - """ - self._epsilon = epsilon - self._generate_quantiles = generate_quantiles +class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for QuantileAccumulator.""" - name = _PATTERN.sub("", name) - with ops.name_scope(name, "QuantileAccumulator") as name: - self._quantile_accumulator_handle = ( - gen_quantile_ops.quantile_stream_resource_handle_op( - container=container, shared_name=name, name=name)) - self._create_op = gen_quantile_ops.create_quantile_accumulator( - self._quantile_accumulator_handle, - init_stamp_token, - epsilon=epsilon, - max_elements=max_elements, - num_quantiles=num_quantiles, - generate_quantiles=generate_quantiles) - is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( - self._quantile_accumulator_handle) - resources.register_resource(self._quantile_accumulator_handle, - self._create_op, is_initialized_op) - self._make_savable(name) - - def _make_savable(self, name): + def __init__(self, resource_handle, create_op, name): + self._resource_handle = resource_handle + self._create_op = create_op stamp_token, state, are_buckets_ready, buckets = ( - gen_quantile_ops.quantile_accumulator_serialize( - self._quantile_accumulator_handle)) + gen_quantile_ops.quantile_accumulator_serialize(resource_handle)) # slice_spec is useful for saving a slice from a variable. # It's not meaningful in quantile accumulator. slice_spec = "" @@ -96,9 +57,8 @@ def make_save_spec(tensor, suffix): specs += [make_save_spec(state, "_state")] specs += [make_save_spec(are_buckets_ready, "_are_buckets_ready")] specs += [make_save_spec(buckets, "buckets")] - super(QuantileAccumulator, - self).__init__(self._quantile_accumulator_handle, specs, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle, + specs, name) def restore(self, restored_tensors, unused_restored_shapes): """Restores the associated quantile accumulator from 'restored_tensors'. @@ -119,24 +79,94 @@ def restore(self, restored_tensors, unused_restored_shapes): buckets = restored_tensors[3] with ops.control_dependencies([self._create_op]): return gen_quantile_ops.quantile_accumulator_deserialize( - self._quantile_accumulator_handle, + self._resource_handle, stamp_token=stamp_token, stream_state=state, are_buckets_ready=are_buckets_ready, buckets=buckets) + +class QuantileAccumulator(tracking.TrackableResource): + """A resource that allows distributed quantile computation.""" + + def __init__(self, + init_stamp_token, + epsilon, + num_quantiles, + max_elements=None, + name=None, + container=None, + generate_quantiles=False): + """Creates a QuantileAccumulator object. + + Args: + init_stamp_token: The initial value for the stamp token. + epsilon: Error bound on the quantile computation. + num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. + name: the name to save the accumulator under. + container: An optional `string`. Defaults to `""` + generate_quantiles: Generate quantiles instead of approximate boundaries. + If true, exactly `num_quantiles` will be produced in the final summary. + """ + self._init_stamp_token = init_stamp_token + self._epsilon = epsilon + self._num_quantiles = num_quantiles + self._max_elements = max_elements + self._container = container + self._generate_quantiles = generate_quantiles + super(QuantileAccumulator, self).__init__() + + name = _PATTERN.sub("", name) + with ops.name_scope(name, "QuantileAccumulator") as name: + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self._init_op, + is_initialized_op) + self._saveable = QuantileAccumulatorSaveable(self.resource_handle, + self._init_op, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + def create_resource(self): + return gen_quantile_ops.quantile_stream_resource_handle_op( + container=self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_quantile_ops.create_quantile_accumulator( + self.resource_handle, + self._init_stamp_token, + epsilon=self._epsilon, + max_elements=self._max_elements, + num_quantiles=self._num_quantiles, + generate_quantiles=self._generate_quantiles) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_quantile_ops.quantile_accumulator_is_initialized( + self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return {"quantile_accumulator", self.saveable} + def get_buckets(self, stamp_token): """Returns quantile buckets created during previous flush.""" are_buckets_ready, buckets = ( gen_quantile_ops.quantile_accumulator_get_buckets( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token)) return are_buckets_ready[0], buckets[0] def schedule_get_buckets(self): """Returns a scheduled read of buckets created during previous flush.""" return batch_ops_utils.ScheduledStampedResourceOp( - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, op=gen_quantile_ops.quantile_accumulator_get_buckets) def _make_summary(self, column, example_weights): @@ -161,14 +191,14 @@ def add_summary(self, stamp_token, column, example_weights): """Adds quantile summary to its stream in resource.""" summary = self._make_summary(column, example_weights) return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) def add_prebuilt_summary(self, stamp_token, summary): """Adds quantile summary to its stream in resource.""" return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) @@ -177,7 +207,7 @@ def schedule_add_summary(self, stamp_token, column, example_weights): summary = self._make_summary(column, example_weights) return batch_ops_utils.ScheduledStampedResourceOp( op=gen_quantile_ops.quantile_accumulator_add_summaries, - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, summaries=summary) def flush(self, stamp_token, next_stamp_token): @@ -190,17 +220,14 @@ def flush(self, stamp_token, next_stamp_token): The flush operation. """ return gen_quantile_ops.quantile_accumulator_flush( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) def flush_summary(self, stamp_token, next_stamp_token): """Finalizes quantile summary stream and resets it for next iteration.""" result = gen_quantile_ops.quantile_accumulator_flush_summary( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result - - def resource(self): - return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index 2e94e353f325f0..ad1191d41236e7 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,12 +26,83 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") -class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): +class StatsAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for StatsAccumulator.""" + + def __init__(self, resource_handle, create_op, is_scalar, name): + self._create_op = create_op + self._resource_handle = resource_handle + self._is_scalar = is_scalar + slice_spec = "" + saver_name = self._resource_handle.name + (stamp_token, num_updates, partition_ids, feature_ids, gradients, + hessians) = self.serialize() + specs = [ + saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, + saver_name + "_stamp"), + saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, + saver_name + "_num_updates"), + saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, + saver_name + "_partition_ids"), + saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, + saver_name + "_feature_ids"), + saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, + saver_name + "_gradients"), + saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, + saver_name + "hessians"), + ] + super(StatsAccumulatorSaveable, self).__init__(self._resource_handle, specs, + name) + + def serialize(self): + """Serializes the stats accumulator state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( + self._resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( + self._resource_handle) + + def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, + gradients, hessians): + """Resets the stats accumulator with the serialized state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree ensemble from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree ensemble variable. + """ + with ops.control_dependencies([self._create_op]): + return self.deserialize( + stamp_token=restored_tensors[0], + num_updates=restored_tensors[1], + partition_ids=restored_tensors[2], + feature_ids=restored_tensors[3], + gradients=restored_tensors[4], + hessians=restored_tensors[5]) + + +class StatsAccumulator(tracking.TrackableResource): """A resource that allows to accumulate gradients and hessians. For consistency guarantees, we use read and write stamp tokens. @@ -58,58 +129,69 @@ def __init__(self, Returns: A `Tensor` of type mutable `string`. The handle to the stats accumulator. """ + self._stamp_token = stamp_token + self._gradient_shape = gradient_shape + self._hessian_shape = hessian_shape + self._container = container + + if (gradient_shape == tensor_shape.scalar() and + hessian_shape == tensor_shape.scalar()): + self._is_scalar = True + else: + self._is_scalar = False + if name is not None: name = _PATTERN.sub("", name) with ops.name_scope(name, "StatsAccumulator") as name: - # Both values are scalars. - if (gradient_shape == tensor_shape.scalar() and - hessian_shape == tensor_shape.scalar()): - self._is_scalar = True - self._resource_handle = (gen_stats_accumulator_ops. - stats_accumulator_scalar_resource_handle_op( - container, name, name=name)) - - create_op = gen_stats_accumulator_ops.create_stats_accumulator_scalar( - self._resource_handle, stamp_token) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( - self._resource_handle)) - else: - self._is_scalar = False - self._resource_handle = (gen_stats_accumulator_ops. - stats_accumulator_tensor_resource_handle_op( - container, name, name=name)) - create_op = gen_stats_accumulator_ops.create_stats_accumulator_tensor( - self._resource_handle, stamp_token, gradient_shape.as_list(), - hessian_shape.as_list()) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( - self._resource_handle)) + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self.initializer, + is_initialized_op) + self._saveable = StatsAccumulatorSaveable( + self.resource_handle, self.initializer, self._is_scalar, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - self._create_op = create_op - slice_spec = "" - saver_name = self._resource_handle.name - (stamp_token, num_updates, partition_ids, feature_ids, gradients, - hessians) = self.serialize() - specs = [ - saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, - saver_name + "_stamp"), - saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, - saver_name + "_num_updates"), - saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, - saver_name + "_partition_ids"), - saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, - saver_name + "_feature_ids"), - saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, - saver_name + "_gradients"), - saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, - saver_name + "hessians"), - ] + def create_resource(self): + if self._is_scalar: + return ( + gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op( + self._container, self._name, name=self._name)) + else: + return ( + gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op( + self._container, self._name, name=self._name)) - super(StatsAccumulator, self).__init__(self._resource_handle, specs, name) - resources.register_resource(self._resource_handle, create_op, - is_initialized_op) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + def initialize(self): + if self._is_scalar: + return gen_stats_accumulator_ops.create_stats_accumulator_scalar( + self.resource_handle, self._stamp_token) + else: + return gen_stats_accumulator_ops.create_stats_accumulator_tensor( + self.resource_handle, self._stamp_token, + self._gradient_shape.as_list(), self._hessian_shape.as_list()) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( + self.resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( + self.resource_handle) + + @property + def saveable(self): + return self._saveable + + def _gather_saveables_for_checkpoint(self): + return {"stats_accumulator", self.saveable} def add(self, stamp_token, partition_ids, feature_ids, gradients, hessians): """Updates the stats accumulator.""" @@ -117,11 +199,11 @@ def add(self, stamp_token, partition_ids, feature_ids, gradients, hessians): partition_ids, feature_ids, gradients, hessians)) if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_add( - [self._resource_handle], stamp_token, [partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_add( - [self._resource_handle], stamp_token, [partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) def schedule_add(self, partition_ids, feature_ids, gradients, hessians): @@ -131,7 +213,7 @@ def schedule_add(self, partition_ids, feature_ids, gradients, hessians): if self._is_scalar: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_scalar_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -139,7 +221,7 @@ def schedule_add(self, partition_ids, feature_ids, gradients, hessians): else: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_tensor_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -153,55 +235,11 @@ def _make_summary(self, partition_ids, feature_ids, gradients, hessians): return gen_stats_accumulator_ops.stats_accumulator_tensor_make_summary( partition_ids, feature_ids, gradients, hessians) - def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, - gradients, hessians): - """Resets the stats accumulator with the serialized state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - def flush(self, stamp_token, next_stamp_token): """Flushes the stats accumulator.""" if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_flush( - self._resource_handle, stamp_token, next_stamp_token) + self.resource_handle, stamp_token, next_stamp_token) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_flush( - self._resource_handle, stamp_token, next_stamp_token) - - def serialize(self): - """Serializes the stats accumulator state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( - self._resource_handle) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( - self._resource_handle) - - def restore(self, restored_tensors, unused_restored_shapes): - """Restores the associated tree ensemble from 'restored_tensors'. - - Args: - restored_tensors: the tensors that were loaded from a checkpoint. - unused_restored_shapes: the shapes this object should conform to after - restore. Not meaningful for trees. - - Returns: - The operation that restores the state of the tree ensemble variable. - """ - with ops.control_dependencies([self._create_op]): - return self.deserialize( - stamp_token=restored_tensors[0], - num_updates=restored_tensors[1], - partition_ids=restored_tensors[2], - feature_ids=restored_tensors[3], - gradients=restored_tensors[4], - hessians=restored_tensors[5]) - - def resource(self): - return self._resource_handle + self.resource_handle, stamp_token, next_stamp_token) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index bd5d5bb695684c..ab5713fbe26ab7 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -386,10 +386,21 @@ def __init__(self, learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + if (learner_config.weak_learner_type == learner_pb2.LearnerConfig + .OBLIVIOUS_DECISION_TREE and learner_config.pruning_mode == learner_pb2 + .LearnerConfig.PRUNING_MODE_UNSPECIFIED): + learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE + if (learner_config.pruning_mode == learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED): learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE + if (learner_config.weak_learner_type == learner_pb2.LearnerConfig + .OBLIVIOUS_DECISION_TREE and + learner_config.pruning_mode == learner_pb2.LearnerConfig.POST_PRUNE): + raise ValueError( + "Post pruning is not implmented for oblivious decision trees.") + if learner_config.constraints.max_tree_depth == 0: # Use 6 as the default maximum depth. learner_config.constraints.max_tree_depth = 6 @@ -418,6 +429,11 @@ def __init__(self, sparse_float_shapes, sparse_int_indices, sparse_int_values, sparse_int_shapes) = extract_features( features, self._feature_columns, use_core_columns) + if (learner_config.weak_learner_type == learner_pb2.LearnerConfig + .OBLIVIOUS_DECISION_TREE and sparse_float_indices): + raise ValueError("Oblivious trees don't handle sparse float features yet." + ) + logging.info("Active Feature Columns: " + str(fc_names)) logging.info("Learner config: " + str(learner_config)) self._fc_names = fc_names @@ -976,7 +992,7 @@ def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict, # Get accumulated steps and examples for the current layer. _, _, _, _, acc_examples, acc_steps = ( - steps_accumulator.serialize()) + steps_accumulator.saveable.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) ensemble_update_ops.append( diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 996c0af10f92ca..5ecd4f341831ce 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -20,9 +20,12 @@ import abc +import six + from tensorflow.python.training.server_lib import ClusterSpec +@six.add_metaclass(abc.ABCMeta) class ClusterResolver(object): """Abstract class for all implementations of ClusterResolvers. diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index dea5a6f9662f69..d94b703700cfcd 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -279,10 +279,10 @@ tensorflow/contrib/linear_optimizer/kernels/g3doc tensorflow/contrib/linear_optimizer/python tensorflow/contrib/linear_optimizer/python/ops # TODO(drpngx): Fix failing imports -# tensorflow/contrib/lite -# tensorflow/contrib/lite/python -# tensorflow/contrib/lite/toco -# tensorflow/contrib/lite/toco/python +# tensorflow/lite +# tensorflow/lite/python +# tensorflow/lite/toco +# tensorflow/lite/toco/python tensorflow/contrib/lookup tensorflow/contrib/losses tensorflow/contrib/losses/python diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 42afbd9105ef37..013180c8908374 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -6,7 +6,7 @@ tensorflow/contrib/boosted_trees/proto tensorflow/contrib/cloud/kernels tensorflow/contrib/decision_trees/proto tensorflow/contrib/gdr -tensorflow/contrib/lite/toco +tensorflow/lite/toco tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 6d86daf5f174a3..ef487d3509bf3c 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -222,17 +222,17 @@ endforeach(python_module) add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite") add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python") + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite/python") add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py") + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite/python/__init__.py") add_custom_command( TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite/python/lite.py) # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 146f4b51a5e5b8..335ac7946485f2 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -179,14 +179,11 @@ def AddOp(self, op): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index fe5e34d258fbc1..d53549048f3316 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """Linear-chain CRF layer. -See the [CRF](https://tensorflow.org/api_guides/python/contrib.crf) guide. - @@crf_binary_score @@crf_decode @@crf_log_likelihood diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index abda1fb2f0d4e8..6ba9187ae218a6 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -374,6 +374,9 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", + # TODO(b/118820960): Re-enable this test in guitar. + "manual", + "noguitar", ], ) @@ -518,6 +521,10 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", + # TODO(b/118768923): Re-enable {a,m,t}san test. + "noasan", + "nomsan", + "notsan", ], ) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index d99d7080bc1ca3..018512ae5a22ea 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -296,7 +296,6 @@ def predict_input_fn(): eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, parameter_server_strategy.ParameterServerStrategy, - collective_all_reduce_strategy.CollectiveAllReduceStrategy ], required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, @@ -307,7 +306,8 @@ def test_complete_flow_standalone_client(self, train_distribute_cls, train_distribute = train_distribute_cls(num_gpus_per_worker=2) if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = eval_distribute_cls( + num_gpus_per_worker=context.num_gpus()) else: eval_distribute = None @@ -336,7 +336,8 @@ def test_estimator_standalone_client(self, train_distribute_cls, num_gpus_per_worker=context.num_gpus()) if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = eval_distribute_cls( + num_gpus_per_worker=context.num_gpus()) else: eval_distribute = None @@ -407,7 +408,6 @@ def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, parameter_server_strategy.ParameterServerStrategy, - collective_all_reduce_strategy.CollectiveAllReduceStrategy ], required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_between_graph( @@ -420,7 +420,8 @@ def test_complete_flow_indepedent_worker_between_graph( self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = eval_distribute_cls( + num_gpus_per_worker=context.num_gpus()) else: eval_distribute = None @@ -459,7 +460,8 @@ def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, num_gpus_per_worker=context.num_gpus()) if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = eval_distribute_cls( + num_gpus_per_worker=context.num_gpus()) else: eval_distribute = None diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index c7036daa3e3321..0fd3acd045170c 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -61,7 +61,6 @@ def get_input_datasets(use_bfloat16=False): # train dataset train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() - train_ds = train_ds.shuffle(100) train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) train_ds = train_ds.batch(64, drop_remainder=True) diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py index 423952c9e254f5..f07ec8234dfe87 100644 --- a/tensorflow/contrib/distribute/python/input_ops.py +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -78,7 +78,7 @@ def _auto_shard_impl(dataset, found_reader_op): elif hasattr(dataset, "_map_func"): # TODO(priyag): Make this check more robust by enforcing some common # property on all map/flatmap/interleave datasets. - map_func_def = dataset._map_func.function_def + map_func_def = dataset._map_func.definition for node in map_func_def.node_def: if node.op in _READER_DATASET_OPS: found_reader_op = True diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index eccff1d9f57a67..33b8a61eb1aaf2 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -598,36 +598,33 @@ def test_calling_model_on_same_dataset(self, distribution): @combinations.generate(strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - loss = 'mse' - user_controlled_model = get_model() - user_controlled_optimizer = gradient_descent.GradientDescentOptimizer( - 0.001) - user_controlled_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - user_controlled_model.compile(user_controlled_optimizer, loss, - metrics=user_controlled_metrics, - distribute=distribution) + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) interleaved_model = get_model() - interleaved_optimizer = gradient_descent.GradientDescentOptimizer(0.001) - interleaved_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - interleaved_model.compile(interleaved_optimizer, loss, - metrics=interleaved_metrics, - distribute=distribution) + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) dataset = get_dataset(distribution) # Call fit with validation interleaved - interleaved_output = interleaved_model.fit(dataset, epochs=2, - steps_per_epoch=2, verbose=0, - validation_data=dataset, - validation_steps=2) + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) # Manually control the validation running after each epoch. user_controlled_output = [] for _ in range(2): user_controlled_model.fit( - dataset, epochs=1, steps_per_epoch=2, verbose=0) + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) user_controlled_output.append( user_controlled_model.evaluate(dataset, steps=2)) @@ -800,8 +797,9 @@ def test_learning_phase_value(self): self.assertAlmostEqual(hist.history['acc'][0], 0, 0) model.set_weights(initial_weights) - evaluate_output = model.evaluate(dataset, steps=20) - self.assertAlmostEqual(evaluate_output[1], 1, 0) + # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. + # evaluate_output = model.evaluate(dataset, steps=20) + # self.assertAlmostEqual(evaluate_output[1], 1, 0) inputs = np.ones((10, 1), dtype=np.float32) predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index bbfd94ed5c0dd5..2aa7f1ae5d6f0b 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -293,6 +293,8 @@ def _call_for_each_replica(self, fn, *args, **kwargs): return mirrored_strategy._call_for_each_replica(self, fn, *args, **kwargs) def _verify_destinations_not_different_worker(self, destinations): + if not self._cluster_spec: + return if destinations is None: return for d in cross_tower_ops_lib.get_devices_from(destinations): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index b8d5d0ecafce70..a9f643c6eccdae 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -355,10 +355,15 @@ def model_fn(): def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - assert hasattr(d, '_cluster_spec') and d._cluster_spec - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if CHIEF in d._cluster_spec.as_dict(): - num_workers += 1 + if task_type: + # Multi-worker + assert hasattr(d, '_cluster_spec') and d._cluster_spec + num_workers = len(d._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d._cluster_spec.as_dict(): + num_workers += 1 + else: + # local + num_workers = 1 with ops.Graph().as_default(), \ self.cached_session(target=master_target, @@ -410,7 +415,8 @@ def step(): if context.num_gpus() < d._num_gpus_per_worker: return True - if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id): + if (not task_type or + multi_worker_util.is_chief(d._cluster_spec, task_type, task_id)): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. @@ -484,10 +490,15 @@ def testLocalSimpleIncrement(self, num_gpus): @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): + def testMinimizeLossGraphDistributed(self, num_gpus): self._run_between_graph_clients(self._test_minimize_loss_graph, self._cluster_spec, num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraphLocal(self, num_gpus): + self._test_minimize_loss_graph(None, None, num_gpus) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 6b47ba499cbbf8..65ef21df09ba34 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -24,7 +24,6 @@ import functools from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu @@ -114,9 +113,8 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, return result -# TODO(jhseu): Stop inheriting from OneDeviceStrategy. -class TPUStrategy(one_device_strategy.OneDeviceStrategy): - """Experimental TPU distribution strategy implementation.""" +class TPUStrategy(distribute_lib.DistributionStrategy): + """TPU distribution strategy implementation.""" def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): """Initializes the TPUStrategy object. @@ -132,9 +130,7 @@ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): num_cores: Number of cores to use on the TPU. If None specified, then auto-detect the cores and topology of the TPU system. """ - # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the - # master node fetched from the cluster resolver. - super(TPUStrategy, self).__init__("/device:CPU:0") + super(TPUStrategy, self).__init__() self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) @@ -146,6 +142,7 @@ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) if "device:TPU:" in d.name} self._device_index = values.PerDevice(device_map) + self._host_device = self.get_host_cpu_device(0) self._tpu_devices = sorted(device_map.keys()) # Only create variables for the number of replicas we're running. self._tpu_devices = self._tpu_devices[:self.num_replicas] @@ -323,7 +320,7 @@ def _call_for_each_replica(self, fn, *args, **kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. kwargs.pop("run_concurrently", None) - with one_device_strategy._OneDeviceReplicaContext(self): # pylint: disable=protected-access + with _TPUReplicaContext(self): return fn(*args, **kwargs) def initialize(self): @@ -402,7 +399,7 @@ def _reduce(self, aggregation, value, destinations): devices = cross_tower_ops_lib.get_devices_from(destinations) if len(devices) == 1: assert device_util.canonicalize(devices[0]) == device_util.canonicalize( - self.get_host_cpu_device(0)) + self._host_device) else: raise ValueError("Multiple devices are not supported for TPUStrategy") @@ -453,6 +450,13 @@ def _unwrap(self, val): return val return [val] + def value_container(self, value): + return value + + def _broadcast(self, tensor, destinations): + del destinations + return tensor + @property def num_replicas(self): return self._num_cores_override or self._tpu_metadata.num_cores @@ -493,6 +497,21 @@ def worker_devices(self): def parameter_devices(self): return self._tpu_devices + def non_slot_devices(self, var_list): + return self._host_device + + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + del colocate_with + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. + with ops.device(self._host_device), distribute_lib.UpdateContext( + self._host_device): + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) + def get_host(self, host_id): if self._tpu_cluster_resolver.get_master() in ("", "local"): return "/replica:0/task:0" @@ -513,3 +532,17 @@ def configure(self, cluster_spec = self._tpu_cluster_resolver.cluster_spec() if cluster_spec: session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + + +class _TPUReplicaContext(distribute_lib.ReplicaContext): + """Replication Context class for TPU Strategy.""" + + # TODO(sourabhbajaj): Call for each tower should be updating this. + def __init__(self, distribution_strategy): + distribute_lib.ReplicaContext.__init__( + self, distribution_strategy, replica_id=0) + + @property + def device(self): + distribute_lib.require_replica_context(self) + return self._distribution_strategy.worker_devices[self._replica_id] diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index a926ffd5982116..bbe335be3e1384 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -73,3 +73,45 @@ py_test( "//tensorflow/python/keras:layers", ], ) + +py_library( + name = "sequence_feature_column_v2", + srcs = ["python/feature_column/sequence_feature_column_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_v2", + ], +) + +py_test( + name = "sequence_feature_column_v2_test", + srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + ":sequence_feature_column_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_v2", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py new file mode 100644 index 00000000000000..6e775afb69af04 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -0,0 +1,547 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This API defines FeatureColumn for sequential input. + +NOTE: This API is a work in progress and will likely be changing frequently. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import collections + + +from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# pylint: disable=protected-access + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + features: A dict mapping keys to tensors. + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + weight_collections: A list of collection names to which the Variable will be + added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc_old._normalize_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, fc_old._SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'You can wrap a sequence_categorical_column with an embedding_column ' + 'or indicator_column. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + None, default_name='sequence_input_layer', values=features.values()): + builder = fc_old._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + target_shape = [shape[0], shape[1], num_elements] + output_tensors.append( + array_ops.reshape(dense_tensor, shape=target_shape)) + sequence_lengths.append(sequence_length) + + fc_old._verify_static_batch_size_equality(output_tensors, ordered_columns) + fc_old._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + sequence_length = _assert_all_equal_and_return(sequence_lengths) + + return array_ops.concat(output_tensors, -1), sequence_length + + +def concatenate_context_input(context_input, sequence_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. + """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + """Returns a feature column that represents sequences of integers. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [watches_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + num_buckets: Range of inputs. Namely, inputs are expected to be in the + range `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace out-of-range inputs. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. + """ + return fc_old._SequenceCategoricalColumn( + fc_old.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +def sequence_categorical_column_with_hash_bucket( + key, hash_bucket_size, dtype=dtypes.string): + """A sequence of categorical terms where ids are set by hashing. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + tokens = sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=1000) + tokens_embedding = embedding_column(tokens, dimension=10) + columns = [tokens_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + hash_bucket_size: An int > 1. The number of buckets. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. + """ + return fc_old._SequenceCategoricalColumn( + fc_old.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """A sequence of categorical terms where ids use a vocabulary file. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + states = sequence_categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + states_embedding = embedding_column(states, dimension=10) + columns = [states_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. If None, it is set to the length of `vocabulary_file`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. + """ + return fc_old._SequenceCategoricalColumn( + fc_old.categorical_column_with_vocabulary_file( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=num_oov_buckets, + default_value=default_value, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): + """A sequence of categorical terms where ids use an in-memory list. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + colors = sequence_categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), + num_oov_buckets=2) + colors_embedding = embedding_column(colors, dimension=3) + columns = [colors_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a + hash of the input value. A positive `num_oov_buckets` can not be specified + with `default_value`. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. + """ + return fc_old._SequenceCategoricalColumn( + fc_old.categorical_column_with_vocabulary_list( + key=key, + vocabulary_list=vocabulary_list, + dtype=dtype, + default_value=default_value, + num_oov_buckets=num_oov_buckets)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32, + normalizer_fn=None): + """Returns a feature column that represents sequences of numeric data. + + Example: + + ```python + temperature = sequence_numeric_column('temperature') + columns = [temperature] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input features. + shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`, + each example must contain `2 * sequence_length` values. + default_value: A single value compatible with `dtype` that is used for + padding the sparse data into a dense `Tensor`. + dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. + + Returns: + A `SequenceNumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int. + ValueError: if any dimension in shape is not a positive integer. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + shape = fc._check_shape(shape=shape, key=key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) + + return SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype, + normalizer_fn=normalizer_fn) + + +def _assert_all_equal_and_return(tensors, name=None): + """Asserts that all tensors are equal and returns the first one.""" + with ops.name_scope(name, 'assert_all_equal', values=tensors): + if len(tensors) == 1: + return tensors[0] + assert_equal_ops = [] + for t in tensors[1:]: + assert_equal_ops.append(check_ops.assert_equal(tensors[0], t)) + with ops.control_dependencies(assert_equal_ops): + return array_ops.identity(tensors[0]) + + +class SequenceNumericColumn( + fc.SequenceDenseColumn, + collections.namedtuple( + 'SequenceNumericColumn', + ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))): + """Represents sequences of numeric data.""" + + @property + def _is_v2_column(self): + return True + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def transform_feature(self, transformation_cache, state_manager): + """See `FeatureColumn` base class. + + In this case, we apply the `normalizer_fn` to the input tensor. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Normalized input tensor. + """ + input_tensor = transformation_cache.get(self.key, state_manager) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor + + @property + def variable_shape(self): + """Returns a `TensorShape` representing the shape of sequence input.""" + return tensor_shape.TensorShape(self.shape) + + def get_sequence_dense_tensor(self, transformation_cache, state_manager): + """Returns a `TensorSequenceLengthPair`. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + """ + sp_tensor = transformation_cache.get(self, state_manager) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self.variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + + # Get the number of timesteps per example + # For the 2D case, the raw values are grouped according to num_elements; + # for the 3D case, the grouping happens in the third dimension, and + # sequence length is not affected. + num_elements = (self.variable_shape.num_elements() + if sp_tensor.shape.ndims == 2 else 1) + seq_length = fc_old._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=num_elements) + + return fc.SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=seq_length) + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py new file mode 100644 index 00000000000000..5ecd85807c55e5 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -0,0 +1,1507 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc_old +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc +from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [2, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[2], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 2, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]], + # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_embedding_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): + + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) + vocabulary_size = 3 + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc_old.embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = fc_old.embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc_old.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_shared_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + + def _get_initializer(embedding_dimension, embedding_values): + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 3., 4.], [0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 5., 6.], [3., 4., 1., 2.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + # Test that columns are reordered alphabetically. + shared_embedding_columns = fc_old.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension, + initializer=_get_initializer(embedding_dimension, embedding_values)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=shared_embedding_columns) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_shared_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence shared embedding column.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_old.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_shared_embedding\. categorical_column must ' + r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b + }, + feature_columns=shared_embedding_columns) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [1, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 1, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[1], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 1, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]], + # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -] + [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_indicator_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) + + vocabulary_size_a = 3 + vocabulary_size_b = 2 + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size_a) + indicator_column_a = fc_old.indicator_column(categorical_column_a) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size_b) + indicator_column_b = fc_old.indicator_column(categorical_column_b) + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[indicator_column_b, indicator_column_a]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_indicator_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc_old.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1] + # example 1, [10.] + 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + [[0.], [1.]], + [[10.], [0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_numeric_column( + self, sparse_input_args, expected_input_layer, expected_sequence_length): + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + + numeric_column = sfc_old.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.] + # example 1, [10., 11., 12., 13.] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + ) + def test_numeric_column_multi_dim( + self, sparse_input_args, expected_input_layer, expected_sequence_length): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + + numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_not_equal(self): + """Tests that an error is raised when sequence lengths are not equal.""" + # Input a with sequence_length = [2, 1] + sparse_input_a = sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + # Input b with sequence_length = [1, 1] + sparse_input_b = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0)), + values=(1., 10.), + dense_shape=(2, 2)) + numeric_column_a = sfc_old.sequence_numeric_column('aaa') + numeric_column_b = sfc_old.sequence_numeric_column('bbb') + + _, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=[numeric_column_a, numeric_column_b]) + + with monitored_session.MonitoredSession() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Condition x == y did not hold element-wise:\] ' + r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + sess.run(sequence_length) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_shape': [2, 2, 4]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_shape': [2, 2, 4]}, + ) + def test_static_shape_from_tensors_numeric( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected_shape': [4, 2, 3]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected_shape': [4, 2, 3]} + ) + def test_static_shape_from_tensors_indicator( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=3) + indicator_column = fc_old.indicator_column(categorical_column) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, feature_columns=[indicator_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + +class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): + """Tests the utility fn concatenate_context_input.""" + + def test_concatenate_context_input(self): + seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2)) + context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + input_layer = sfc.concatenate_context_input(context_input, seq_input) + + expected = np.array([ + [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]], + [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]] + ], dtype=np.float32) + with monitored_session.MonitoredSession() as sess: + output = sess.run(input_layer) + self.assertAllEqual(expected, output) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_3', + 'seq_input_arg': np.arange(100).reshape(10, 10)}, + {'testcase_name': 'rank_gt_3', + 'seq_input_arg': np.arange(100).reshape(5, 5, 2, 2)} + ) + def test_sequence_input_throws_error(self, seq_input_arg): + seq_input = ops.convert_to_tensor(seq_input_arg) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'): + sfc.concatenate_context_input(context_input, seq_input) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_2', + 'context_input_arg': np.arange(100)}, + {'testcase_name': 'rank_gt_2', + 'context_input_arg': np.arange(100).reshape(5, 5, 4)} + ) + def test_context_input_throws_error(self, context_input_arg): + context_input = ops.convert_to_tensor(context_input_arg) + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_seq_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'sequence_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_context_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'context_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + +class InputLayerTest(test.TestCase): + """Tests input_layer with sequence feature columns.""" + + def test_embedding_column(self): + """Tests that error is raised for sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc_old.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc_old.input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + """Tests that error is raised for sequence indicator column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc_old.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc_old.input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + _assert_sparse_tensor_indices_shape(test_case, expected, actual) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + +def _assert_sparse_tensor_indices_shape(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((1, 2, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithHashBucketTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_hash_bucket( + 'aaa', hash_bucket_size=10) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_indices_shape( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyFileTest( + test.TestCase, parameterized.TestCase): + + def _write_vocab(self, vocab_strings, file_name): + vocab_file = os.path.join(self.get_temp_dir(), file_name) + with open(vocab_file, 'w') as f: + f.write('\n'.join(vocab_strings)) + return vocab_file + + def setUp(self): + super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp() + + vocab_strings = ['omar', 'stringer', 'marlo'] + self._wire_vocabulary_file_name = self._write_vocab(vocab_strings, + 'wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_dynamic_zero_length(self): + """Tests _get_sparse_tensors with a dynamic sequence length.""" + inputs = sparse_tensor.SparseTensorValue( + indices=np.zeros((0, 2)), values=[], dense_shape=(2, 0)) + expected = sparse_tensor.SparseTensorValue( + indices=np.zeros((0, 3)), + values=np.array((), dtype=np.int64), + dense_shape=(2, 0, 1)) + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + input_placeholder_shape = list(inputs.dense_shape) + # Make second dimension (sequence length) dynamic. + input_placeholder_shape[1] = None + input_placeholder = array_ops.sparse_placeholder( + dtypes.string, shape=input_placeholder_shape) + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': input_placeholder})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + result = id_weight_pair.id_tensor.eval( + session=sess, feed_dict={input_placeholder: inputs}) + _assert_sparse_tensor_value( + self, expected, result) + + +class SequenceCategoricalColumnWithVocabularyListTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[7., 11.], [0., 0.]], + # example 1, ids [[0, 1], [2]] + [[2, 3.5], [7., 11.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [[1], [0, 2]] + [[3., 5.], [4., 6.5]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected, embedding_lookup.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceSharedEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [0, 2] + # example 2, ids [0] + # example 3, ids [] + indices=((0, 0), (1, 0), (1, 1), (2, 0)), + values=(1, 0, 2, 0), + dense_shape=(4, 2)) + + expected_lookups_a = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + expected_lookups_b = [ + # example 0, ids [1] + [[3., 5.], [0., 0.]], + # example 1, ids [0, 2] + [[1., 2.], [7., 11.]], + # example 2, ids [0] + [[1., 2.], [0., 0.]], + # example 3, ids [] + [[0., 0.], [0., 0.]], + ] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[0] + embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[0] + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual( + expected_lookups_a, embedding_lookup_a.eval(session=sess)) + self.assertAllEqual( + expected_lookups_b, embedding_lookup_b.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length_a = [1, 2] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0, 2] + # example 1, ids [1] + indices=((0, 0), (0, 1), (1, 0)), + values=(0, 2, 1), + dense_shape=(2, 2)) + expected_sequence_length_b = [2, 1] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + sequence_length_a = sess.run(sequence_length_a) + self.assertAllEqual(expected_sequence_length_a, sequence_length_a) + self.assertEqual(np.int64, sequence_length_a.dtype) + sequence_length_b = sess.run(sequence_length_b) + self.assertAllEqual(expected_sequence_length_b, sequence_length_b) + self.assertEqual(np.int64, sequence_length_b.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length_a = [0, 1, 2, 0, 1, 0] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [0, 1] + indices=((0, 0), (4, 0), (5, 0), (5, 1)), + values=(2, 1, 0, 1), + dense_shape=(6, 2)) + expected_sequence_length_b = [1, 0, 0, 0, 1, 2] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + + shared_embedding_columns = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length_a, sequence_length_a.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length_b, sequence_length_b.eval(session=sess)) + + +class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [2, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 2, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [[0, 1], [2]] + [[1., 1., 0.], [0., 0., 1.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [[1], [2, 2]] + [[0., 1., 0.], [0., 0., 2.]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc_old.indicator_column(categorical_column) + + indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc_old.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +def _get_sequence_dense_tensor(column, features): + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), None) + + +class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): + + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, values [0., 1] + # example 1, [10.] + 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected': [ + [[0.], [1.]], + [[10.], [0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]]}, + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = _get_sequence_dense_tensor( + numeric_column, {'aaa': inputs}) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = _get_sequence_dense_tensor( + numeric_column, {'aaa': sparse_input}) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_dense_tensor': [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]}, + {'testcase_name': '3D', + 'sparse_input_args': { + 'indices': ((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6), + (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6), + (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 8)}, + 'expected_dense_tensor': [ + [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]], + [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]], + [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]], + [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]}, + ) + def test_get_dense_tensor_multi_dim( + self, sparse_input_args, expected_dense_tensor): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = _get_sequence_dense_tensor( + numeric_column, {'aaa': sparse_input}) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '2D_with_shape', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 1], + 'shape': (2,)}, + {'testcase_name': '3D_with_shape', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (2,)}, + ) + def test_sequence_length(self, inputs_args, expected_sequence_length, shape): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=shape) + + _, sequence_length = _get_sequence_dense_tensor( + numeric_column, {'aaa': inputs}) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = _get_sequence_dense_tensor( + numeric_column, {'aaa': sparse_input}) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 249debbdf6dff4..cd747df4d69d2c 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -1,15 +1,16 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - package(default_visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", ]) +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 95f5ba90aba6ff..e72e50585a3861 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -15,10 +15,6 @@ """Framework utilities. -See the -[Contrib Framework](https://tensorflow.org/api_guides/python/contrib.framework) -guide. - @@assert_same_float_dtype @@assert_scalar @@assert_scalar_int diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 288d4853207176..936b29a4f50794 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -22,6 +22,8 @@ import ssl import struct +import six + from tensorflow.contrib.ignite.python.ops import gen_dataset_ops from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops @@ -30,6 +32,7 @@ from tensorflow.python.framework import tensor_shape +@six.add_metaclass(abc.ABCMeta) class Readable(object): """Readable abstract class that exposes methods to do reading-related diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index 7b7ac4f347e30d..b7d77130bd03ba 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -540,7 +540,8 @@ def odeint(func, **options) -class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): +@six.add_metaclass(abc.ABCMeta) +class _FixedGridIntegrator(object): """Base class for fixed-grid ODE integrators.""" def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): diff --git a/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py b/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py index db38b471520e19..04ecdbfdb66257 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py @@ -35,7 +35,6 @@ class DenseKernelMapper(object): This class is abstract. Users should not create instances of this class. """ - __metaclass__ = abc.ABCMeta @abc.abstractmethod def map(self, input_tensor): diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index b4fe8cac74cb7d..e6596bfdfb9b15 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -1,15 +1,16 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - package(default_visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", ]) +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "py_test") diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index af8e673f5906ad..32f3006b749e3b 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -14,10 +14,6 @@ # ============================================================================== """Ops for building neural network layers, regularizers, summaries, etc. -See the -[Contrib Layers](https://tensorflow.org/api_guides/python/contrib.layers) -guide. - @@avg_pool2d @@avg_pool3d @@batch_norm diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 53c8ae5d089364..222404b19db2b9 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -194,6 +194,7 @@ class _DeepEmbeddingLookupArguments( pass +@six.add_metaclass(abc.ABCMeta) class _FeatureColumn(object): """Represents a feature column abstraction. @@ -205,7 +206,6 @@ class _FeatureColumn(object): Following classes (_SparseColumn, _RealValuedColumn, ...) are concrete instances. """ - __metaclass__ = abc.ABCMeta @abc.abstractproperty @deprecation.deprecated( diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 28a6f5aed99b14..7bf2ac62d76d67 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -19,9 +19,6 @@ [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) for migration instructions. -See the [Contrib Learn](https://tensorflow.org/api_guides/python/contrib.learn) -guide. - @@BaseEstimator @@Estimator @@Trainable diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 3efceab3375d3a..8bc869db895b75 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -404,7 +404,6 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, Users should not instantiate or subclass this class. Instead, use an `Estimator`. """ - __metaclass__ = abc.ABCMeta # Note that for Google users, this is overridden with # learn_runner.EstimatorConfig. diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 9e5aaf3118dfed..8a461a0bd7ba45 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -1368,7 +1368,7 @@ def testMutableHashTableIsOnPs(self): table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) - self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device) + self.assertDeviceEqual('/job:ps/task:0', table.resource_handle.device) self.assertDeviceEqual('/job:ps/task:0', output.device) def testMutableHashTableIsLocal(self): @@ -1378,7 +1378,7 @@ def testMutableHashTableIsLocal(self): table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) - self.assertDeviceEqual('', table._table_ref.device) + self.assertDeviceEqual('', table.resource_handle.device) self.assertDeviceEqual('', output.device) def testTaskIsSetOnWorkerWhenJobNameIsSet(self): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c6f79e00d5a5a5..c1b97d8b49613e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -55,6 +55,7 @@ from tensorflow.python.util.deprecation import deprecated +@six.add_metaclass(abc.ABCMeta) class Head(object): """Interface for the head/top of a model. @@ -132,7 +133,6 @@ def _train_op_fn(loss): ... update train_op and hooks in ModelFnOps and return ``` """ - __metaclass__ = abc.ABCMeta @abc.abstractproperty def logits_dimension(self): @@ -504,7 +504,6 @@ def no_op_train_fn(loss): class _SingleHead(Head): """Interface for a single head/top of a model.""" - __metaclass__ = abc.ABCMeta def __init__( self, problem_type, logits_dimension, label_name=None, diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 10881ca885599b..5dedf548f73d27 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -25,7 +25,10 @@ import abc +import six + +@six.add_metaclass(abc.ABCMeta) class Evaluable(object): """Interface for objects that are evaluatable by, e.g., `Experiment`. @@ -33,7 +36,6 @@ class Evaluable(object): [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) for general migration instructions. """ - __metaclass__ = abc.ABCMeta @abc.abstractproperty def model_dir(self): diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index a1a3f20dcd8cb5..1ea9e5d67a95df 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -25,13 +25,15 @@ import abc +import six + +@six.add_metaclass(abc.ABCMeta) class Trainable(object): """Interface for objects that are trainable by, e.g., `Experiment`. THIS CLASS IS DEPRECATED. """ - __metaclass__ = abc.ABCMeta @abc.abstractmethod def fit(self, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index de58db90e9f375..a001555e8f257c 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -30,7 +30,9 @@ from tensorflow.python.ops import math_ops -class ShardedMutableDenseHashTable(lookup.LookupInterface): +# TODO(rohanj): This should subclass Checkpointable and implement +# _gather_saveables_for_checkpoint. +class ShardedMutableDenseHashTable(object): """A sharded version of MutableDenseHashTable. It is designed to be interface compatible with LookupInterface and @@ -52,9 +54,10 @@ def __init__(self, num_shards=1, checkpoint=True, name='ShardedMutableHashTable'): + self._key_dtype = key_dtype + self._value_dtype = value_dtype with ops.name_scope(name, 'sharded_mutable_hash_table') as scope: - super(ShardedMutableDenseHashTable, self).__init__(key_dtype, - value_dtype, scope) + self._table_name = scope table_shards = [] for i in range(num_shards): table_shards.append( @@ -72,6 +75,10 @@ def __init__(self, self._value_shape = self._table_shards[0]._value_shape # pylint: enable=protected-access + @property + def name(self): + return self._table_name + @property def _num_shards(self): return len(self._table_shards) @@ -106,6 +113,7 @@ def _check_keys(self, keys): keys.get_shape()) def lookup(self, keys, name=None): + """Looks up `keys` in a table, outputs the corresponding values.""" if keys.dtype.base_dtype != self._key_dtype: raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' % (self._key_dtype, keys.dtype)) @@ -134,6 +142,7 @@ def lookup(self, keys, name=None): return result def insert(self, keys, values, name=None): + """Inserts `keys` in a table.""" self._check_keys(keys) num_shards = self._num_shards if num_shards == 1: diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD deleted file mode 100644 index 626f733540264c..00000000000000 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -# Description: -# TensorFlow Lite microcontroller example. - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -load( - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", - "tflite_micro_cc_test", -) - -tflite_micro_cc_test( - name = "micro_speech_test", - srcs = [ - "micro_speech_test.cc", - "no_features_data.cc", - "no_features_data.h", - "tiny_conv_model_data.cc", - "tiny_conv_model_data.h", - "yes_features_data.cc", - "yes_features_data.h", - ], - tags = [ - "nomsan", - ], - deps = [ - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - "//tensorflow/contrib/lite/experimental/micro/kernels:all_ops_resolver", - "//tensorflow/contrib/lite/experimental/micro/kernels:micro_ops", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", - "//tensorflow/contrib/lite/schema:schema_fbs", - ], -) diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/BUILD b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD deleted file mode 100644 index a012f950e6f58f..00000000000000 --- a/tensorflow/contrib/lite/experimental/micro/kernels/BUILD +++ /dev/null @@ -1,107 +0,0 @@ -package(default_visibility = [ - "//visibility:public", -]) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -load( - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", - "tflite_micro_cc_test", -) - -cc_library( - name = "micro_ops", - srcs = [ - "depthwise_conv.cc", - "fully_connected.cc", - "softmax.cc", - ], - hdrs = [ - ], - copts = tflite_copts(), - deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/contrib/lite/kernels:op_macros", - "//tensorflow/contrib/lite/kernels:padding", - "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:reference_base", - "//tensorflow/contrib/lite/kernels/internal:tensor", - ], -) - -cc_library( - name = "all_ops_resolver", - srcs = [ - "all_ops_resolver.cc", - ], - hdrs = [ - "all_ops_resolver.h", - ], - copts = tflite_copts(), - deps = [ - ":micro_ops", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - ], -) - -cc_library( - name = "test_utils", - srcs = [ - ], - hdrs = [ - "test_utils.h", - ], - copts = tflite_copts(), - deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/core/api", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", - ], -) - -tflite_micro_cc_test( - name = "depthwise_conv_test", - srcs = [ - "depthwise_conv_test.cc", - ], - deps = [ - ":all_ops_resolver", - ":test_utils", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", - ], -) - -tflite_micro_cc_test( - name = "fully_connected_test", - srcs = [ - "fully_connected_test.cc", - ], - deps = [ - ":all_ops_resolver", - ":test_utils", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", - ], -) - -tflite_micro_cc_test( - name = "softmax_test", - srcs = [ - "softmax_test.cc", - ], - deps = [ - ":all_ops_resolver", - ":test_utils", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/experimental/micro:micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", - ], -) diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc deleted file mode 100644 index 52d71350d3ba9a..00000000000000 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" - -namespace tflite { -namespace { - -using ::testing::ElementsAreArray; -using ::testing::IsEmpty; - -class ReshapeOpModel : public SingleOpModel { - public: - ReshapeOpModel(std::initializer_list input_shape, - std::initializer_list new_shape, - bool use_shape_input_tensor = false) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - int shape_input_tensor = - use_shape_input_tensor ? AddInput(TensorType_INT32) : -1; - SetBuiltinOp( - BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, - CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) - .Union()); - if (use_shape_input_tensor) { - BuildInterpreter({input_shape, GetShape(shape_input_tensor)}); - PopulateTensor(shape_input_tensor, new_shape); - } else { - BuildInterpreter({input_shape}); - } - } - - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); - } - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input_; - int output_; -}; - -TEST(ReshapeOpTest, MismatchedDimensions) { - EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}), - "num_input_elements != num_output_elements"); -} - -TEST(ReshapeOpTest, TooManyDimensions) { - EXPECT_DEATH( - ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}), - "Found too many dimensions"); -} - -TEST(ReshapeOpTest, TooManySpecialDimensions) { - EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}), - "stretch_dim != -1"); -} - -TEST(ReshapeOpTest, SimpleTest) { - ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} - -TEST(ReshapeOpTest, ShapeTensorInput) { - ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}, /*use_shape_input_tensor=*/true); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); -} - -TEST(ReshapeOpTest, WithStretchDimension) { - ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); -} - -TEST(ReshapeOpTest, ScalarOutput) { - ReshapeOpModel m({1}, {}); - m.SetInput({3}); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutputShape(), IsEmpty()); -} - -TEST(ReshapeOpTest, LegacyScalarOutput) { - ReshapeOpModel m({1}, {0}); - m.SetInput({3}); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutputShape(), IsEmpty()); -} - -} // namespace -} // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index efc2398d116d2a..893ddd78231c8a 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -1,192 +1,12 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "py_test") - -filegroup( - name = "interpreter_test_data", - srcs = glob(["**/testdata/*"]), - visibility = ["//tensorflow:__subpackages__"], -) - -py_library( - name = "interpreter", - srcs = [ - "interpreter.py", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_test( - name = "interpreter_test", - srcs = ["interpreter_test.py"], - data = [":interpreter_test_data"], - srcs_version = "PY2AND3", - tags = ["no_oss"], - deps = [ - ":interpreter", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//third_party/py/numpy", - ], -) - -py_binary( - name = "tflite_convert", - srcs = ["tflite_convert.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":lite", - ], -) +licenses(["notice"]) +# DO NOT USE THIS TARGET. TensorFlow Lite has moved to tensorflow/lite. py_library( name = "lite", - srcs = ["lite.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":convert", - ":convert_saved_model", - ":interpreter", - ":lite_constants", - ":op_hint", - "//tensorflow/python:graph_util", - "//tensorflow/python/keras", - "//tensorflow/python/saved_model:constants", - "//tensorflow/python/saved_model:loader", - ], -) - -py_test( - name = "lite_test", - srcs = ["lite_test.py"], - data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", - "no_windows", - ], - deps = [ - ":lite", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], -) - -py_library( - name = "lite_constants", - srcs = ["lite_constants.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/lite/toco:toco_flags_proto_py", - ], -) - -py_library( - name = "convert", - srcs = ["convert.py"], + srcs = ["__init__.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":lite_constants", - "//tensorflow/contrib/lite/toco:model_flags_proto_py", - "//tensorflow/contrib/lite/toco:toco_flags_proto_py", - "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco", - "//tensorflow/contrib/lite/toco/python:toco_from_protos", - "//tensorflow/python:platform", - ], -) - -py_library( - name = "op_hint", - srcs = ["op_hint.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/graph_editor:graph_editor_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework", - "//tensorflow/python:platform", - "//tensorflow/python:util", - ], -) - -py_test( - name = "convert_test", - srcs = ["convert_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no-internal-py3", - "no_oss", - ], - deps = [ - ":convert", - ":interpreter", - ":op_hint", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:platform_test", - "//tensorflow/python:session", - ], -) - -py_library( - name = "convert_saved_model", - srcs = ["convert_saved_model.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow/contrib/lite:__subpackages__"], - deps = [ - ":convert", - "//tensorflow/python:graph_util", - "//tensorflow/python:platform", - "//tensorflow/python/saved_model", - ], -) - -py_binary( - name = "create_custom_op", - srcs = ["create_custom_op.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:platform", - "@absl_py//absl/flags", - ], -) - -py_test( - name = "convert_saved_model_test", - srcs = ["convert_saved_model_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", - "no_windows", - ], - visibility = ["//visibility:public"], - deps = [ - ":convert_saved_model", - "//tensorflow/python:client_testlib", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:platform_test", - "//tensorflow/python:session", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/keras", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model", + "//tensorflow/lite/python:lite", ], ) diff --git a/tensorflow/contrib/lite/python/__init__.py b/tensorflow/contrib/lite/python/__init__.py new file mode 100644 index 00000000000000..27b1ffb251e764 --- /dev/null +++ b/tensorflow/contrib/lite/python/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow import lite + +import warnings as _warnings + +WARNING = ("WARNING: TF Lite has moved from tf.contrib.lite to tf.lite. Please " + "update your imports. This will be a breaking error in TensorFlow " + "version 2.0.") +_warnings.warn(WARNING, PendingDeprecationWarning) diff --git a/tensorflow/contrib/lite/testdata/add.bin b/tensorflow/contrib/lite/testdata/add.bin deleted file mode 100644 index aef0fe3d82c9d9..00000000000000 Binary files a/tensorflow/contrib/lite/testdata/add.bin and /dev/null differ diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 5abef822e82a1e..e52fb5ab1431e0 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -42,7 +42,6 @@ from tensorflow.python.ops.lookup_ops import TextFileInitializer from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer # pylint: enable=unused-import -from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.deprecation import deprecated @@ -92,7 +91,7 @@ def index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.init.run()` once. + `tf.tables_initializer.run()` or `table.initializer.run()` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -203,7 +202,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.init.run()` once. + `tf.tables_initializer.run()` or `table.initializer.run()` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -289,7 +288,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): return table.lookup(tensor) -class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase): +class MutableHashTable(LookupInterface): """A generic mutable hash table implementation. Data can be inserted by calling the insert method and removed by calling the @@ -339,43 +338,56 @@ def __init__(self, self._default_value = ops.convert_to_tensor(default_value, dtype=value_dtype) self._value_shape = self._default_value.get_shape() + self._checkpoint = checkpoint + self._key_dtype = key_dtype + self._value_dtype = value_dtype + self._name = name - executing_eagerly = context.executing_eagerly() - if executing_eagerly and shared_name is None: + if context.executing_eagerly() and shared_name is None: # TODO(allenl): This will leak memory due to kernel caching by the # shared_name attribute value (but is better than the alternative of # sharing everything by default when executing eagerly; hopefully creating # tables in a loop is uncommon). shared_name = "table_%d" % (ops.uid(),) + self._shared_name = shared_name + super(MutableHashTable, self).__init__(key_dtype, value_dtype) + + self._resource_handle = self.create_resource() + if checkpoint: + saveable = MutableHashTable._Saveable(self, name) + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + + def create_resource(self): # The table must be shared if checkpointing is requested for multi-worker # training to work correctly. Use the node name if no shared_name has been # explicitly specified. - use_node_name_sharing = checkpoint and shared_name is None + use_node_name_sharing = self._checkpoint and self._shared_name is None if self._default_value.get_shape().ndims == 0: - self._table_ref = gen_lookup_ops.mutable_hash_table_v2( - shared_name=shared_name, + table_ref = gen_lookup_ops.mutable_hash_table_v2( + shared_name=self._shared_name, use_node_name_sharing=use_node_name_sharing, - key_dtype=key_dtype, - value_dtype=value_dtype, - name=name) + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + name=self._name) else: - self._table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( - shared_name=shared_name, + table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( + shared_name=self._shared_name, use_node_name_sharing=use_node_name_sharing, - key_dtype=key_dtype, - value_dtype=value_dtype, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, value_shape=self._default_value.get_shape(), - name=name) - if executing_eagerly: - op_name = None + name=self._name) + + if context.executing_eagerly(): + self._table_name = None else: - op_name = self._table_ref.op.name.split("/")[-1] - super(MutableHashTable, self).__init__(key_dtype, value_dtype, - op_name) + self._table_name = table_ref.op.name.split("/")[-1] + return table_ref - if checkpoint: - saveable = MutableHashTable._Saveable(self, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + @property + def name(self): + return self._table_name def size(self, name=None): """Compute the number of elements in this table. @@ -386,10 +398,11 @@ def size(self, name=None): Returns: A scalar tensor containing the number of elements in this table. """ - with ops.name_scope(name, "%s_Size" % self._name, - [self._table_ref]) as name: - with ops.colocate_with(self._table_ref): - return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name) + with ops.name_scope(name, "%s_Size" % self.name, + [self.resource_handle]) as name: + with ops.colocate_with(self.resource_handle): + return gen_lookup_ops.lookup_table_size_v2( + self.resource_handle, name=name) def remove(self, keys, name=None): """Removes `keys` and its associated values from the table. @@ -411,11 +424,12 @@ def remove(self, keys, name=None): raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_remove" % self._name, - (self._table_ref, keys, self._default_value)) as name: + with ops.name_scope( + name, "%s_lookup_table_remove" % self.name, + (self.resource_handle, keys, self._default_value)) as name: # pylint: disable=protected-access op = gen_lookup_ops.lookup_table_remove_v2( - self._table_ref, keys, name=name) + self.resource_handle, keys, name=name) return op @@ -436,12 +450,13 @@ def lookup(self, keys, name=None): Raises: TypeError: when `keys` do not match the table data types. """ - with ops.name_scope(name, "%s_lookup_table_find" % self._name, - (self._table_ref, keys, self._default_value)) as name: + with ops.name_scope( + name, "%s_lookup_table_find" % self.name, + (self.resource_handle, keys, self._default_value)) as name: keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self._table_ref): + with ops.colocate_with(self.resource_handle): values = gen_lookup_ops.lookup_table_find_v2( - self._table_ref, keys, self._default_value, name=name) + self.resource_handle, keys, self._default_value, name=name) return values def insert(self, keys, values, name=None): @@ -461,14 +476,14 @@ def insert(self, keys, values, name=None): TypeError: when `keys` or `values` doesn't match the table data types. """ - with ops.name_scope(name, "%s_lookup_table_insert" % self._name, - [self._table_ref, keys, values]) as name: + with ops.name_scope(name, "%s_lookup_table_insert" % self.name, + [self.resource_handle, keys, values]) as name: keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") values = ops.convert_to_tensor(values, self._value_dtype, name="values") - with ops.colocate_with(self._table_ref): + with ops.colocate_with(self.resource_handle): # pylint: disable=protected-access op = gen_lookup_ops.lookup_table_insert_v2( - self._table_ref, keys, values, name=name) + self.resource_handle, keys, values, name=name) return op def export(self, name=None): @@ -481,11 +496,11 @@ def export(self, name=None): A pair of tensors with the first tensor containing all keys and the second tensors containing all values in the table. """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, - [self._table_ref]) as name: - with ops.colocate_with(self._table_ref): + with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, + [self.resource_handle]) as name: + with ops.colocate_with(self.resource_handle): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self._table_ref, self._key_dtype, self._value_dtype, name=name) + self.resource_handle, self._key_dtype, self._value_dtype, name=name) return exported_keys, exported_values def _gather_saveables_for_checkpoint(self): @@ -507,12 +522,12 @@ def __init__(self, table, name): def restore(self, restored_tensors, restored_shapes): del restored_shapes # unused # pylint: disable=protected-access - with ops.colocate_with(self.op._table_ref): + with ops.colocate_with(self.op.resource_handle): return gen_lookup_ops.lookup_table_import_v2( - self.op._table_ref, restored_tensors[0], restored_tensors[1]) + self.op.resource_handle, restored_tensors[0], restored_tensors[1]) -class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase): +class MutableDenseHashTable(LookupInterface): """A generic mutable hash table implementation using tensors as backing store. Data can be inserted by calling the insert method and removed by calling the @@ -581,42 +596,55 @@ def __init__(self, """ self._default_value = ops.convert_to_tensor( default_value, dtype=value_dtype, name="default_value") + self._key_dtype = key_dtype + self._value_dtype = value_dtype + self._initial_num_buckets = initial_num_buckets self._value_shape = self._default_value.get_shape() + self._checkpoint = checkpoint + self._name = name - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. Use the node name if no shared_name has been - # explicitly specified. - use_node_name_sharing = checkpoint and shared_name is None - empty_key = ops.convert_to_tensor( + self._empty_key = ops.convert_to_tensor( empty_key, dtype=key_dtype, name="empty_key") - deleted_key = ops.convert_to_tensor( + self._deleted_key = ops.convert_to_tensor( deleted_key, dtype=key_dtype, name="deleted_key") - executing_eagerly = context.executing_eagerly() - if executing_eagerly and shared_name is None: + if context.executing_eagerly() and shared_name is None: # TODO(allenl): This will leak memory due to kernel caching by the # shared_name attribute value (but is better than the alternative of # sharing everything by default when executing eagerly; hopefully creating # tables in a loop is uncommon). shared_name = "table_%d" % (ops.uid(),) - self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( - empty_key=empty_key, - deleted_key=deleted_key, - shared_name=shared_name, + self._shared_name = shared_name + super(MutableDenseHashTable, self).__init__(key_dtype, value_dtype) + + self._resource_handle = self.create_resource() + if checkpoint: + saveable = MutableDenseHashTable._Saveable(self, name) + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + + def create_resource(self): + # The table must be shared if checkpointing is requested for multi-worker + # training to work correctly. Use the node name if no shared_name has been + # explicitly specified. + use_node_name_sharing = self._checkpoint and self._shared_name is None + table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( + empty_key=self._empty_key, + deleted_key=self._deleted_key, + shared_name=self._shared_name, use_node_name_sharing=use_node_name_sharing, - value_dtype=value_dtype, + value_dtype=self._value_dtype, value_shape=self._value_shape, - initial_num_buckets=initial_num_buckets, - name=name) - if executing_eagerly: - op_name = None + initial_num_buckets=self._initial_num_buckets, + name=self._name) + if context.executing_eagerly(): + self._table_name = None else: - op_name = self._table_ref.op.name.split("/")[-1] - super(MutableDenseHashTable, self).__init__( - key_dtype, value_dtype, op_name) + self._table_name = table_ref.op.name.split("/")[-1] + return table_ref - if checkpoint: - saveable = MutableDenseHashTable._Saveable(self, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + @property + def name(self): + return self._table_name def size(self, name=None): """Compute the number of elements in this table. @@ -627,10 +655,11 @@ def size(self, name=None): Returns: A scalar tensor containing the number of elements in this table. """ - with ops.name_scope(name, "%s_Size" % self._name, - [self._table_ref]) as name: - with ops.colocate_with(self._table_ref): - return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=name) + with ops.name_scope(name, "%s_Size" % self.name, + [self.resource_handle]) as name: + with ops.colocate_with(self.resource_handle): + return gen_lookup_ops.lookup_table_size_v2( + self.resource_handle, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -649,12 +678,12 @@ def lookup(self, keys, name=None): Raises: TypeError: when `keys` do not match the table data types. """ - with ops.name_scope(name, "%s_lookup_table_find" % self._name, - [self._table_ref, keys]) as name: + with ops.name_scope(name, "%s_lookup_table_find" % self.name, + [self.resource_handle, keys]) as name: keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self._table_ref): + with ops.colocate_with(self.resource_handle): values = gen_lookup_ops.lookup_table_find_v2( - self._table_ref, keys, self._default_value, name=name) + self.resource_handle, keys, self._default_value, name=name) return values @@ -675,14 +704,14 @@ def insert(self, keys, values, name=None): TypeError: when `keys` or `values` doesn't match the table data types. """ - with ops.name_scope(name, "%s_lookup_table_insert" % self._name, - [self._table_ref, keys, values]) as name: + with ops.name_scope(name, "%s_lookup_table_insert" % self.name, + [self.resource_handle, keys, values]) as name: keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") values = ops.convert_to_tensor( values, dtype=self._value_dtype, name="values") - with ops.colocate_with(self._table_ref): + with ops.colocate_with(self.resource_handle): op = gen_lookup_ops.lookup_table_insert_v2( - self._table_ref, keys, values, name=name) + self.resource_handle, keys, values, name=name) return op def remove(self, keys, name=None): @@ -705,11 +734,12 @@ def remove(self, keys, name=None): raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_remove" % self._name, - (self._table_ref, keys, self._default_value)) as name: + with ops.name_scope( + name, "%s_lookup_table_remove" % self.name, + (self.resource_handle, keys, self._default_value)) as name: # pylint: disable=protected-access op = gen_lookup_ops.lookup_table_remove_v2( - self._table_ref, keys, name=name) + self.resource_handle, keys, name=name) return op @@ -723,11 +753,11 @@ def export(self, name=None): A pair of tensors with the first tensor containing all keys and the second tensors containing all values in the table. """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, - [self._table_ref]) as name: - with ops.colocate_with(self._table_ref): + with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, + [self.resource_handle]) as name: + with ops.colocate_with(self.resource_handle): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self._table_ref, self._key_dtype, self._value_dtype, name=name) + self.resource_handle, self._key_dtype, self._value_dtype, name=name) return exported_keys, exported_values @@ -751,6 +781,6 @@ def __init__(self, table, name): def restore(self, restored_tensors, restored_shapes): del restored_shapes # unused # pylint: disable=protected-access - with ops.colocate_with(self.op._table_ref): + with ops.colocate_with(self.op.resource_handle): return gen_lookup_ops.lookup_table_import_v2( - self.op._table_ref, restored_tensors[0], restored_tensors[1]) + self.op.resource_handle, restored_tensors[0], restored_tensors[1]) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 35b0d1bc4447cd..5e99ef460518fa 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -50,7 +50,7 @@ def testHashTable(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -74,7 +74,7 @@ def testHashTableFindHighRank(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -94,7 +94,7 @@ def testHashTableInitWithPythonArrays(self): lookup.KeyValueTensorInitializer( keys, values, value_dtype=dtypes.int64), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -111,7 +111,7 @@ def testHashTableInitWithNumPyArrays(self): values = np.array([0, 1, 2], dtype=np.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -156,7 +156,7 @@ def testHashTableWithTensorDefault(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "tank"]) output = table.lookup(input_string) @@ -171,7 +171,7 @@ def testHashTableWithSparseTensorInput(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() sp_indices = [[0, 0], [0, 1], [1, 0]] sp_shape = [2, 2] @@ -194,7 +194,7 @@ def testSignatureMismatch(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() # Ref types do not produce a lookup signature mismatch. input_string_ref = variables.Variable("brain") @@ -238,10 +238,10 @@ def testInitializeTwice(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() with self.assertRaisesOpError("Table already initialized"): - table.init.run() + table.initializer.run() def testInitializationWithInvalidDimensions(self): with self.cached_session(): @@ -273,13 +273,13 @@ def testMultipleSessions(self): # Init the table in the first session. with session1: - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) # Init the table in the second session and verify that we do not get a # "Table already initialized" error. with session2: - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) def testHashTableInt32String(self): @@ -289,7 +289,7 @@ def testHashTableInt32String(self): values = constant_op.constant(["brain", "salad", "surgery"]) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() input_tensor = constant_op.constant([0, 1, -1]) output = table.lookup(input_tensor) @@ -1669,7 +1669,7 @@ def test_index_table_from_file_with_vocab_size_too_large(self): table = lookup.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Invalid vocab_size", table.init.run) + "Invalid vocab_size", table.initializer.run) def test_index_table_from_file_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab8.txt") @@ -1717,14 +1717,14 @@ def test_string(self): init = lookup.KeyValueTensorInitializer( ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) table = lookup.HashTable(init, default_value=-1) - table.init.run() + table.initializer.run() def test_int64(self): with ops.Graph().as_default(), self.cached_session(): init = lookup.KeyValueTensorInitializer( (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) table = lookup.HashTable(init, default_value=-1) - table.init.run() + table.initializer.run() def test_int32(self): with ops.Graph().as_default(), self.cached_session(): @@ -1733,7 +1733,7 @@ def test_int32(self): table = lookup.HashTable(init, default_value=-1) with self.assertRaisesRegexp( errors_impl.OpError, "No OpKernel was registered"): - table.init.run() + table.initializer.run() class IndexTableFromTensor(test.TestCase): @@ -2021,7 +2021,7 @@ def testInitializeStringTable(self): dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), default_value) - self.evaluate(table.init) + self.evaluate(table.initializer) output = table.lookup(constant_op.constant(["brain", "salad", "tank"])) @@ -2040,7 +2040,7 @@ def testInitializeInt64Table(self): dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), default_value) - table.init.run() + table.initializer.run() output = table.lookup( constant_op.constant((42, 1, 11), dtype=dtypes.int64)) @@ -2059,7 +2059,7 @@ def testInitializeIndexTable(self): lookup.TextFileInitializer(vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index), default_value) - table.init.run() + table.initializer.run() input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) output = table.lookup(input_values) @@ -2081,7 +2081,7 @@ def testMultiColumn(self): lookup.TextFileInitializer(vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index), default_value) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "surgery"]) output = table.lookup(input_string) @@ -2103,7 +2103,7 @@ def testInvalidDataTypeInMultiColumn(self): key_index, dtypes.int64, value_index), default_value) with self.assertRaisesOpError("is not a valid"): - table.init.run() + table.initializer.run() def testInvalidDataType(self): vocabulary_file = self._createVocabFile("one_column_3.txt") @@ -2131,7 +2131,7 @@ def testInvalidIndex(self): default_value) with self.assertRaisesOpError("Invalid number of columns"): - table.init.run() + table.initializer.run() def testInitializeSameTableWithMultipleNodes(self): vocabulary_file = self._createVocabFile("one_column_5.txt") @@ -2200,7 +2200,7 @@ def testInitializeWithVocabSize(self): default_value) # Initialize from file. - table1.init.run() + table1.initializer.run() self.assertEquals(vocab_size, table1.size().eval()) vocabulary_file2 = self._createVocabFile("one_column7.txt") @@ -2215,7 +2215,7 @@ def testInitializeWithVocabSize(self): vocab_size=vocab_size), default_value) with self.assertRaisesOpError("Invalid vocab_size"): - table2.init.run() + table2.initializer.run() vocab_size = 1 vocabulary_file3 = self._createVocabFile("one_column3.txt") @@ -2230,7 +2230,7 @@ def testInitializeWithVocabSize(self): default_value) # Smaller vocab size reads only vocab_size records. - table3.init.run() + table3.initializer.run() self.assertEquals(vocab_size, table3.size().eval()) def testFeedVocabularyName(self): @@ -2248,11 +2248,11 @@ def testFeedVocabularyName(self): # Initialize with non existing file (old_file.txt) should fail. # TODO(yleon): Update message, which might change per FileSystem. with self.assertRaisesOpError("old_file.txt"): - table.init.run() + table.initializer.run() # Initialize the model feeding the vocabulary file. filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) - table.init.run(feed_dict={filenames[0]: vocabulary_file}) + table.initializer.run(feed_dict={filenames[0]: vocabulary_file}) input_string = constant_op.constant(["brain", "salad", "tank"]) output = table.lookup(input_string) @@ -2294,7 +2294,7 @@ def testIdToStringTable(self): vocab_file, vocab_size=vocab_size), default_value) - table.init.run() + table.initializer.run() input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) @@ -2311,7 +2311,7 @@ def testStringToIdTable(self): lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) @@ -2329,7 +2329,7 @@ def testInt64ToIdTable(self): lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), default_value) - table.init.run() + table.initializer.run() out = table.lookup( constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)) @@ -2358,7 +2358,7 @@ def testStringIdTableWithHashBuckets(self): default_value), oov_buckets) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) @@ -2380,7 +2380,7 @@ def testInt32IdTableWithHashBuckets(self): oov_buckets, key_dtype=dtypes.int32) - table.init.run() + table.initializer.run() values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32) @@ -2401,7 +2401,7 @@ def testInt64IdTableWithHashBuckets(self): default_value), oov_buckets) - table.init.run() + table.initializer.run() values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64) @@ -2416,7 +2416,7 @@ def testStringIdTableWithOnlyHashBucket(self): # Set a table that only uses hash buckets, for each input value returns # an id calculated by fingerprint("input") mod oov_buckets. table = lookup.IdTableWithHashBuckets(None, oov_buckets) - table.init.run() + table.initializer.run() values = constant_op.constant(("brain", "salad", "surgery")) @@ -2438,7 +2438,7 @@ def testInt32IdTableWithOnlyHashBucket(self): # an id calculated by fingerprint("input") mod oov_buckets. table = lookup.IdTableWithHashBuckets( None, oov_buckets, key_dtype=dtypes.int32) - table.init.run() + table.initializer.run() input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32) @@ -2520,7 +2520,7 @@ def testIdTableWithHashBucketsInitializationAcrossSessions(self): shared_name=shared_name), oov_buckets) - table1.init.run() + table1.initializer.run() input_string_1 = constant_op.constant( ["brain", "salad", "surgery", "UNK"]) @@ -2536,7 +2536,7 @@ def testIdTableWithHashBucketsInitializationAcrossSessions(self): oov_buckets = 1 # Underlying lookup table already initialized in previous session. - # No need to call table2.init.run() + # No need to call table2.initializer.run() table2 = lookup.IdTableWithHashBuckets( lookup.HashTable( lookup.TextFileIdTableInitializer( @@ -2605,7 +2605,7 @@ def testSparseTensor(self): vocab_file, vocab_size=3), -1), 1) - table.init.run() + table.initializer.run() sp_ids = table.lookup(sp_features) @@ -2634,7 +2634,7 @@ def testInt32SparseTensor(self): -1), 1, key_dtype=dtypes.int32) - table.init.run() + table.initializer.run() sp_ids = table.lookup(sp_features) @@ -2663,7 +2663,7 @@ def testInt64SparseTensor(self): -1), 1, key_dtype=dtypes.int64) - table.init.run() + table.initializer.run() sp_ids = table.lookup(sp_features) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 5b29f0185f275b..7ea6e34cf50ed8 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -740,22 +740,22 @@ ifeq ($(WITH_TFLITE_FLEX), true) TF_CC_SRCS += $(EAGER_CC_SRCS) TF_LITE_CORE_CC_ALL_SRCS := \ - $(wildcard tensorflow/contrib/lite/*.cc) \ - $(wildcard tensorflow/contrib/lite/*.c) \ - $(wildcard tensorflow/contrib/lite/c/*.c) \ - $(wildcard tensorflow/contrib/lite/core/api/*.cc) + $(wildcard tensorflow/lite/*.cc) \ + $(wildcard tensorflow/lite/*.c) \ + $(wildcard tensorflow/lite/c/*.c) \ + $(wildcard tensorflow/lite/core/api/*.cc) TF_LITE_CORE_CC_ALL_SRCS += \ - $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ - $(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ - $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ - $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ + $(wildcard tensorflow/lite/kernels/*.cc) \ + $(wildcard tensorflow/lite/kernels/internal/*.cc) \ + $(wildcard tensorflow/lite/kernels/internal/optimized/*.cc) \ + $(wildcard tensorflow/lite/kernels/internal/reference/*.cc) \ $(PROFILER_SRCS) \ - $(wildcard tensorflow/contrib/lite/kernels/*.c) \ - $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ - $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ - $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ - $(wildcard tensorflow/contrib/lite/delegates/flex/*.cc) + $(wildcard tensorflow/lite/kernels/*.c) \ + $(wildcard tensorflow/lite/kernels/internal/*.c) \ + $(wildcard tensorflow/lite/kernels/internal/optimized/*.c) \ + $(wildcard tensorflow/lite/kernels/internal/reference/*.c) \ + $(wildcard tensorflow/lite/delegates/flex/*.cc) # Hack. This shouldn't be here? TF_LITE_CORE_CC_ALL_SRCS += \ @@ -764,14 +764,14 @@ ifeq ($(WITH_TFLITE_FLEX), true) # Remove any duplicates. TF_LITE_CORE_CC_ALL_SRCS := $(sort $(TF_LITE_CORE_CC_ALL_SRCS)) TF_LITE_CORE_CC_EXCLUDE_SRCS := \ - $(wildcard tensorflow/contrib/lite/*test.cc) \ - $(wildcard tensorflow/contrib/lite/*/*test.cc) \ - $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ - $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ - $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ - $(wildcard tensorflow/contrib/lite/delegates/flex/test_util.cc) \ - $(wildcard tensorflow/contrib/lite/nnapi_delegate.cc) \ - $(wildcard tensorflow/contrib/lite/mmap_allocation_disabled.cc) + $(wildcard tensorflow/lite/*test.cc) \ + $(wildcard tensorflow/lite/*/*test.cc) \ + $(wildcard tensorflow/lite/*/*/*test.cc) \ + $(wildcard tensorflow/lite/*/*/*/*test.cc) \ + $(wildcard tensorflow/lite/kernels/test_util.cc) \ + $(wildcard tensorflow/lite/delegates/flex/test_util.cc) \ + $(wildcard tensorflow/lite/nnapi_delegate.cc) \ + $(wildcard tensorflow/lite/mmap_allocation_disabled.cc) # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(TF_LITE_CORE_CC_EXCLUDE_SRCS), $(TF_LITE_CORE_CC_ALL_SRCS)) diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 6c3b02e12b3082..1293e59cbcba86 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -142,7 +142,7 @@ First, download and install JetPack for Android version 3.2 or greater from [Nvi git clone https://github.com/tensorflow/tensorflow.git cd tensorflow JETPACK=$HOME/JetPack_Android_3.2 -TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda-9.0/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so" +TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so" ``` #### Building all CUDA-enabled native binaries: diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index fb9e77ae1bcfc3..dc29694449729f 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -34,7 +34,7 @@ echo "********************************************************************" echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." echo "You are currently using an older version. Please switch over to TensorFlow Lite." echo "" -echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite" echo "********************************************************************" echo "" diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index 9cee4f5916d3b4..9a8059ce50041f 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -35,7 +35,7 @@ echo "********************************************************************" echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." echo "You are currently using an older version. Please switch over to TensorFlow Lite." echo "" -echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite" echo "********************************************************************" echo "" diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index eab93f2cc5ed3d..24a4a03f232272 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -248,6 +248,7 @@ tensorflow/core/kernels/spectrogram_op.cc tensorflow/core/kernels/split_lib_cpu.cc tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc +tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 15d1171e113e6f..f789c83e005ab7 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -22,6 +22,8 @@ import abc +import six + from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -40,6 +42,7 @@ from tensorflow.python.util import nest +@six.add_metaclass(abc.ABCMeta) class _OptimizableVariable(object): """Interface for abstracting over variables in the optimizers.""" diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 94a2d9672dba74..b35c4fde1a2c70 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -205,6 +205,10 @@ py_test( size = "large", srcs = ["python/quantize_parameterized_test.py"], srcs_version = "PY2AND3", + # TODO(b/118839526): Re-enable msan test. + tags = [ + "nomsan", + ], deps = [ ":fold_batch_norms", ":quantize", diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 0ab19c91bb036a..a1f2b5902663e9 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -145,7 +145,7 @@ Mobilenet-v2, and Inception-v3) using this tool: Our pre-trained models are available in the -TensorFlow Lite model repository. The code used to generate +TensorFlow Lite model repository. The code used to generate these models is available. diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py index aa3ca991c060b2..cfbf5bf30f9ba2 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher.py +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -21,7 +21,10 @@ import abc import itertools +import six + +@six.add_metaclass(abc.ABCMeta) class Pattern(object): """The parent class of all patterns (e.g. OpTypePattern and OneofPattern).""" diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 92ca3f20395441..338923f75125ed 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -91,48 +91,50 @@ def Quantize(graph, # If `scope` is given, only quantize it if the consumer of weights # (the layer op) is in the right scope. - _InsertQuantOp( - context, - 'weights_quant', - layer_match.weight_tensor.op, - input_to_ops_map.ConsumerOperations(layer_match.weight_tensor.op), - is_training, - moving_avg=False, - ema_decay=ema_decay, - quant_delay=quant_delay, - narrow_range=True, - vars_collection=vars_collection, - bits=weight_bits, - symmetric=symmetric, - consumer_scope=scope) + if layer_match.weight_tensor is not None: + _InsertQuantOp( + context, + 'weights_quant', + layer_match.weight_tensor.op, + input_to_ops_map.ConsumerOperations(layer_match.weight_tensor.op), + is_training, + moving_avg=False, + ema_decay=ema_decay, + quant_delay=quant_delay, + narrow_range=True, + vars_collection=vars_collection, + bits=weight_bits, + symmetric=symmetric, + consumer_scope=scope) # Quantize the activations. - consumer_ops = input_to_ops_map.ConsumerOperations( - layer_match.activation_op) - add_context = context - if layer_match.bypass_op: - pattern_match_result = re.search(r'^(.*)/([^/]+)', context) - if pattern_match_result is not None: - add_context = pattern_match_result.group(1) - else: - add_context = '' - # If `scope` is given, only quantize it if the producer of weights - # (usually it's the layer op) is in the right scope. - _InsertQuantOp( - add_context, - 'act_quant', - layer_match.activation_op, - consumer_ops, - is_training, - moving_avg=True, - ema_decay=ema_decay, - quant_delay=quant_delay, - vars_collection=vars_collection, - bits=activation_bits, - symmetric=symmetric, - init_min=0.0, - producer_scope=scope) - quantized_ops.add(layer_match.activation_op) + if layer_match.activation_op is not None: + consumer_ops = input_to_ops_map.ConsumerOperations( + layer_match.activation_op) + add_context = context + if layer_match.bypass_op: + pattern_match_result = re.search(r'^(.*)/([^/]+)', context) + if pattern_match_result is not None: + add_context = pattern_match_result.group(1) + else: + add_context = '' + # If `scope` is given, only quantize it if the producer of weights + # (usually it's the layer op) is in the right scope. + _InsertQuantOp( + add_context, + 'act_quant', + layer_match.activation_op, + consumer_ops, + is_training, + moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits, + symmetric=symmetric, + init_min=0.0, + producer_scope=scope) + quantized_ops.add(layer_match.activation_op) # Quantize the inputs and output to the bypass (if it exists). The input to # the bypass is the bias add, and the output is the activation. @@ -547,6 +549,8 @@ def _FindLayersToQuantize(graph): for match_result in sep_conv_matcher.match_graph(graph): layer_op = match_result.get_op(layer_pattern) weight_tensor = match_result.get_tensor(weight_identity_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(weight_resource_var_pattern) activation_op = match_result.get_op(layer_pattern) if layer_op not in matched_layer_set: matched_layer_set.add(layer_op) diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 212d902a3c6479..5681a213fe5eaf 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import quantize from tensorflow.python.framework import ops @@ -26,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope @@ -525,6 +527,43 @@ def _TestSkipReshapeQuantization(self, is_training): self.assertTrue( 'FakeQuantWithMinMaxVars' in [i.op.type for i in reshape.op.inputs]) + def testSeparableConvWithResourceVar(self): + graph = ops.Graph() + with graph.as_default(): + with variable_scope.variable_scope('', use_resource=True): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + kernel_size, depth_multiplier = 3, 1 + depthwise_shape = [kernel_size, kernel_size, depth, depth_multiplier] + depthwise_weights = variables.model_variable( + 'depthwise_weights', shape=depthwise_shape) + strides = [1, 1, 1, 1] + with variable_scope.variable_scope('depthwise_conv_1'): + conv1 = nn.depthwise_conv2d( + input1, depthwise_weights, strides, padding='SAME') + with variable_scope.variable_scope('depthwise_conv_2'): + conv2 = nn.depthwise_conv2d( + conv1, depthwise_weights, strides, padding='SAME') + math_ops.add(conv2, input1, name='add') + + quantize.Quantize(graph, True) + + # Test that the weights and activations of all convs have been quantized. + quant_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name( + 'depthwise_conv_1/weights_quant/' + quant_node_name) + self.assertEqual(weights_quant.type, quant_node_name) + act_quant = graph.get_operation_by_name('depthwise_conv_1/act_quant/' + + quant_node_name) + self.assertEqual(act_quant.type, quant_node_name) + + weights_quant = graph.get_operation_by_name( + 'depthwise_conv_2/weights_quant/' + quant_node_name) + self.assertEqual(weights_quant.type, quant_node_name) + act_quant = graph.get_operation_by_name('depthwise_conv_2/act_quant/' + + quant_node_name) + self.assertEqual(act_quant.type, quant_node_name) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD index b3f32b8f34e7b9..38fcca03116721 100644 --- a/tensorflow/contrib/resampler/BUILD +++ b/tensorflow/contrib/resampler/BUILD @@ -50,6 +50,7 @@ tf_kernel_library( prefix = "resampler_ops", deps = [ ":resampler_ops_op_lib", + "//tensorflow/compiler/tf2xla/kernels:resampler_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", ], diff --git a/tensorflow/contrib/resampler/ops/resampler_ops.cc b/tensorflow/contrib/resampler/ops/resampler_ops.cc index 5ab212032e50ac..f785d4ee5fcd63 100644 --- a/tensorflow/contrib/resampler/ops/resampler_ops.cc +++ b/tensorflow/contrib/resampler/ops/resampler_ops.cc @@ -25,7 +25,7 @@ REGISTER_OP("Resampler") .Input("data: T") .Input("warp: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle data; ShapeHandle warp; @@ -48,7 +48,7 @@ REGISTER_OP("ResamplerGrad") .Input("grad_output: T") .Output("grad_data: T") .Output("grad_warp: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(0)); c->set_output(1, c->input(1)); diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 391df8cdb4b1c6..e124867415f94f 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -196,6 +196,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/lstm_ops_test.py"], additional_deps = [ ":rnn_py", + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 026bf08ced33cf..cbc8af5350276b 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """RNN Cells and additional RNN operations. -See [Contrib RNN](https://tensorflow.org/api_guides/python/contrib.rnn) guide. - @@RNNCell @@LayerRNNCell diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index 9535a76566748e..d37210d4b81203 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -32,15 +32,26 @@ struct TensorCuBlasGemm { const T* b, int ldb, float beta, T* c, int ldc); }; +template +struct gemm_compute_type { + typedef T type; +}; + +template <> +struct gemm_compute_type { + typedef float type; +}; + template struct TensorBlasGemm; template struct TensorBlasGemm { static void compute(OpKernelContext* ctx, const Device& d, bool transa, - bool transb, float alpha, + bool transb, typename gemm_compute_type::type alpha, typename TTypes::ConstMatrix a, - typename TTypes::ConstMatrix b, float beta, + typename TTypes::ConstMatrix b, + typename gemm_compute_type::type beta, typename TTypes::Matrix c) { int64 m = c.dimensions()[0]; int64 n = c.dimensions()[1]; @@ -55,19 +66,23 @@ struct TensorBlasGemm { template struct TensorBlasGemm { static void compute(OpKernelContext* ctx, const Device& d, bool transa, - bool transb, T alpha, typename TTypes::ConstMatrix a, - typename TTypes::ConstMatrix b, T beta, + bool transb, typename gemm_compute_type::type alpha, + typename TTypes::ConstMatrix a, + typename TTypes::ConstMatrix b, + typename gemm_compute_type::type beta, typename TTypes::Matrix c) { Eigen::array, 1> contract_pairs; contract_pairs[0] = Eigen::IndexPair(transa == false, transb == true); - if (alpha == T(1) && beta == T(0)) { + if (alpha == typename gemm_compute_type::type(1.f) && + beta == typename gemm_compute_type::type(0.f)) { c.device(d) = a.contract(b, contract_pairs); - } else if (alpha == T(1) && beta == T(1)) { + } else if (alpha == typename gemm_compute_type::type(1.f) && + beta == typename gemm_compute_type::type(1.f)) { c.device(d) += a.contract(b, contract_pairs); } else { - c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) + - c.constant(beta) * c; + c.device(d) = c.constant(T(alpha)) * a.contract(b, contract_pairs) + + c.constant(T(beta)) * c; } } }; diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.h b/tensorflow/contrib/rnn/kernels/gru_ops.h index 3e2cb39e64bb3f..38be58fa104f8b 100644 --- a/tensorflow/contrib/rnn/kernels/gru_ops.h +++ b/tensorflow/contrib/rnn/kernels/gru_ops.h @@ -88,7 +88,9 @@ struct GRUBlockCellFprop : public GRUCell { typename TTypes::ConstMatrix const_x_h_prev(x_h_prev.data(), x_h_prev.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar); + ctx, d, false, false, typename gemm_compute_type::type(1.f), + const_x_h_prev, w_ru, typename gemm_compute_type::type(0.f), + r_u_bar); // Creating a bias matrix for adding by broadcasting 'b_ru' Eigen::array broadcast_shape({batch_size_, 1}); @@ -107,7 +109,8 @@ struct GRUBlockCellFprop : public GRUCell { typename TTypes::ConstMatrix const_x_h_prevr(x_h_prevr.data(), x_h_prevr.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c); + ctx, d, false, false, typename gemm_compute_type::type(1.f), + const_x_h_prevr, w_c, typename gemm_compute_type::type(0.f), c); Eigen::array b_c_shape({1, b_c.dimensions()[0]}); c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape)); @@ -148,9 +151,10 @@ struct GRUBlockCellBprop : public GRUCell { // [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T typename TTypes::ConstMatrix const_d_c_bar(d_c_bar.data(), d_c_bar.dimensions()); - TensorBlasGemm::compute(ctx, d, false, true, T(1), - const_d_c_bar, w_c, T(0), - d_x_comp2_and_h_prevr); + TensorBlasGemm::compute( + ctx, d, false, true, typename gemm_compute_type::type(1.f), + const_d_c_bar, w_c, typename gemm_compute_type::type(0.f), + d_x_comp2_and_h_prevr); d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends()); d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r); @@ -164,7 +168,8 @@ struct GRUBlockCellBprop : public GRUCell { typename TTypes::ConstMatrix const_d_r_bar_u_bar( d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions()); TensorBlasGemm::compute( - ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0), + ctx, d, false, true, typename gemm_compute_type::type(1.f), + const_d_r_bar_u_bar, w_ru, typename gemm_compute_type::type(0.f), d_x_comp1_and_h_prev_comp1); // d_x = d_x_comp1 + d_x_comp2 diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index ee08d306f84baa..d369bc12ae88da 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -61,7 +61,8 @@ void LSTMBlockCellFpropWithEigen( // states1 = xh * w + b typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + ctx, d, false, false, typename gemm_compute_type::type(1.f), const_xh, + w, typename gemm_compute_type::type(0.f), icfo); Eigen::array b_shape({1, b.dimensions()[0]}); Eigen::array broadcast_shape({cell.batch_size(), 1}); icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); @@ -87,11 +88,11 @@ void LSTMBlockCellFpropWithEigen( if (use_peephole) { auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias) + f_peep) + f.constant(T(forget_bias)) + f_peep) .sigmoid(); } else { f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias)) + f.constant(T(forget_bias))) .sigmoid(); } @@ -100,7 +101,7 @@ void LSTMBlockCellFpropWithEigen( if (cell_clip > 0.0f) { cs.device(d) = - cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); + cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op()); } // co = tanh(cs) @@ -225,6 +226,7 @@ void LSTMBlockCellBpropWithEigen( template struct LSTMBlockCellBprop; DEFINE_CPU_SPECS(float); +DEFINE_CPU_SPECS(Eigen::half); #undef DEFINE_CPU_SPECS } // namespace functor @@ -373,7 +375,7 @@ class LSTMBlockCellOp : public OpKernel { Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint("T"), \ LSTMBlockCellOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA @@ -398,7 +400,6 @@ namespace functor { DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(Eigen::half); -// DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // end namespace functor @@ -661,7 +662,7 @@ class LSTMBlockCellGradOp : public OpKernel { Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ LSTMBlockCellGradOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA @@ -1008,7 +1009,7 @@ class BlockLSTMOp : public OpKernel { Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint("T"), \ BlockLSTMOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA @@ -1283,7 +1284,7 @@ class BlockLSTMGradOp : public OpKernel { Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ BlockLSTMGradOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index b664b0f45ee086..057e851aba68c4 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -169,7 +169,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, f[cid] = f_local; T cs_local = i_local * ci_local + f_local * cs_prev[cid]; - if (cell_clip_t > strict_cast(0.0f)) { + if (cell_clip > 0.0f) { cs_local = clip_op(cs_local, cell_clip_t); } cs[cid] = cs_local; @@ -248,7 +248,8 @@ void LSTMBlockCellFpropWithCUDA( // states1 = xh * w typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, 1.f, const_xh, w, 0.f, icfo); + ctx, d, false, false, typename gemm_compute_type::type(1.f), const_xh, + w, typename gemm_compute_type::type(0.f), icfo); // Add bias, apply non-linearities and gating. // diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 9ce0b399ba173b..d5700d2a200f6c 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.rnn.python.kernel_tests import benchmarking @@ -27,6 +28,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_bitwise_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import rnn @@ -38,7 +41,70 @@ block_lstm = lstm_ops._block_lstm # pylint: disable=protected-access -def blocks_match(sess, use_peephole): +class _MaskedRandomUniformInitializer(init_ops.RandomUniform): + """Initializer for uniform dist tensors with trailing bits zeroed-out. + + Allow returning tensors with last few mantissa bits set to 0. This potentially + helps avoid getting into precision issues when testing low precision (float16) + computation. + """ + + def __init__(self, + minval=0, + maxval=None, + seed=None, + dtype=dtypes.float16, + num_valid_mantissa_bits=4): + """Constructor. + + Args: + minval: A python scalar or a scalar tensor. Lower bound of the range of + random values to generate. + maxval: A python scalar or a scalar tensor. Upper bound of the range of + random values to generate. Defaults to 1 for float types. + seed: A Python integer. Used to create random seeds. See + `tf.set_random_seed` for behavior. + dtype: The data type. Only supports tf.float16 for now. + num_valid_mantissa_bits: number of non-zero mantissa bits, default to 4. + + Raises: + ValueError: An error if `dtype` is not tf.float16. + """ + if dtype not in (dtypes.float16,): + raise ValueError("dtype: %s not supported" % dtype.name) + + super(_MaskedRandomUniformInitializer, self).__init__( + minval=minval, maxval=maxval, seed=seed, dtype=dtype) + self._num_mantissa_bits = 10 + self._num_valid_mantissa_bits = num_valid_mantissa_bits + + def __call__(self, shape, dtype=dtypes.float16, partition_info=None): + if dtype and dtype != dtypes.float16: + raise ValueError("dtype: %s not supported" % dtype.name) + res = super(_MaskedRandomUniformInitializer, self).__call__( + shape, dtype, partition_info) + # get uint16 view of the underlying buffer. + res = gen_array_ops.bitcast(res, dtypes.uint16) + + # mask the last `shift` mantissa bits. + shift = self._num_mantissa_bits - self._num_valid_mantissa_bits + mask = (0xffff >> shift) << shift + res = gen_bitwise_ops.bitwise_and(res, mask) + + # restore float16 view. + return gen_array_ops.bitcast(res, dtype) + + +def _get_initializer(init_bound, dtype, seed): + if dtype == dtypes.float16: + return _MaskedRandomUniformInitializer( + -init_bound, init_bound, dtype=dtype, seed=seed) + else: + return init_ops.random_uniform_initializer( + -init_bound, init_bound, dtype=dtype, seed=seed) + + +def blocks_match(sess, use_peephole, dtype=dtypes.float32, cell_clip=None): batch_size = 2 input_size = 3 cell_size = 4 @@ -47,36 +113,42 @@ def blocks_match(sess, use_peephole): inputs = [] for _ in range(sequence_length): inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) + np.random.randn(batch_size, input_size), dtype=dtype) inputs.append(inp) stacked_inputs = array_ops.stack(inputs) - initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212) + init_bound = 1e-1 if dtype == dtypes.float16 else 1e-2 + initializer = _get_initializer(init_bound, dtype=dtype, seed=19890212) with variable_scope.variable_scope("test", initializer=initializer): # magic naming so that the cells pick up these variables and reuse them if use_peephole: wci = variable_scope.get_variable( - "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32) + "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtype) wcf = variable_scope.get_variable( - "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32) + "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtype) wco = variable_scope.get_variable( - "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32) + "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtype) w = variable_scope.get_variable( "rnn/lstm_cell/kernel", shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) + dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[cell_size * 4], - dtype=dtypes.float32, + dtype=dtype, initializer=init_ops.zeros_initializer()) basic_cell = rnn_cell.LSTMCell( - cell_size, use_peepholes=use_peephole, state_is_tuple=True, reuse=True) + cell_size, + use_peepholes=use_peephole, + cell_clip=cell_clip, + dtype=dtype, + state_is_tuple=True, + reuse=True) basic_outputs_op, basic_state_op = rnn.static_rnn( - basic_cell, inputs, dtype=dtypes.float32) + basic_cell, inputs, dtype=dtype) if use_peephole: _, _, _, _, _, _, block_outputs_op = block_lstm( @@ -87,7 +159,7 @@ def blocks_match(sess, use_peephole): wci=wci, wcf=wcf, wco=wco, - cell_clip=0, + cell_clip=cell_clip, use_peephole=True) else: _, _, _, _, _, _, block_outputs_op = block_lstm( @@ -95,13 +167,15 @@ def blocks_match(sess, use_peephole): inputs, w, b, - cell_clip=0) + cell_clip=cell_clip) fused_cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=use_peephole, reuse=True, + cell_size, + cell_clip=cell_clip, + use_peephole=use_peephole, + reuse=True, name="rnn/lstm_cell") - fused_outputs_op, fused_state_op = fused_cell( - stacked_inputs, dtype=dtypes.float32) + fused_outputs_op, fused_state_op = fused_cell(stacked_inputs, dtype=dtype) sess.run([variables.global_variables_initializer()]) basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]]) @@ -127,7 +201,19 @@ def blocks_match(sess, use_peephole): block_wgrads, fused_wgrads) -class LSTMBlockCellTest(test.TestCase): +class LSTMBlockCellTest(test.TestCase, parameterized.TestCase): + + TEST_CASES = ({ + "testcase_name": "Fp32", + "dtype": dtypes.float32, + "rtol": 1e-6, + "atol": 1e-6 + }, { + "testcase_name": "Fp16", + "dtype": dtypes.float16, + "rtol": 8e-3, + "atol": 8e-4 + }) def testNoneDimsWithDynamicRNN(self): with self.session(use_gpu=True, graph=ops.Graph()) as sess: @@ -314,41 +400,43 @@ def testLSTMBasicToBlockCellPeeping(self): for basic, block in zip(basic_res, block_res): self.assertAllClose(basic, block) - def testLSTMBasicToBlock(self): - with self.session(use_gpu=True) as sess: + def LSTMBasicToBlockTestHelper(self, + dtype=dtypes.float32, + use_peephole=False, + cell_clip=None, + rtol=1e-6, + atol=1e-6): + with self.session(use_gpu=True, graph=ops.Graph()) as sess: (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, fused_wgrads) = blocks_match( - sess, use_peephole=False) + sess, use_peephole=use_peephole, dtype=dtype, cell_clip=cell_clip) - self.assertAllClose(basic_outputs, block_outputs) - self.assertAllClose(basic_grads, block_grads) + self.assertAllClose(basic_outputs, block_outputs, rtol=rtol, atol=atol) + self.assertAllClose(basic_grads, block_grads, rtol=rtol, atol=atol) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) + self.assertAllClose(basic, block, rtol=rtol, atol=atol) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(block_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) + self.assertAllClose(basic_outputs, fused_outputs, rtol=rtol, atol=atol) + self.assertAllClose(basic_state, fused_state, rtol=rtol, atol=atol) + self.assertAllClose(basic_grads, fused_grads, rtol=rtol, atol=atol) + for basic, fused in zip(basic_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=rtol, atol=atol) - def testLSTMBasicToBlockPeeping(self): - with self.session(use_gpu=True) as sess: - (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, - basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, - fused_wgrads) = blocks_match( - sess, use_peephole=True) + @parameterized.named_parameters(*TEST_CASES) + def testLSTMBasicToBlock(self, dtype, rtol, atol): + self.LSTMBasicToBlockTestHelper( + dtype, use_peephole=False, rtol=rtol, atol=atol) - self.assertAllClose(basic_outputs, block_outputs) - self.assertAllClose(basic_grads, block_grads) - for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) + @parameterized.named_parameters(*TEST_CASES) + def testLSTMBasicToBlockPeeping(self, dtype, rtol, atol): + self.LSTMBasicToBlockTestHelper( + dtype, use_peephole=True, rtol=rtol, atol=atol) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(block_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) + @parameterized.named_parameters(*TEST_CASES) + def testLSTMBasicToBlockCellClip(self, dtype, rtol, atol): + self.LSTMBasicToBlockTestHelper( + dtype, use_peephole=True, cell_clip=0.5, rtol=rtol, atol=atol) def testLSTMFusedSequenceLengths(self): """Verify proper support for sequence lengths in LSTMBlockFusedCell.""" @@ -444,16 +532,21 @@ def benchmarkLSTMBlockCellFpropWithDynamicRNN(self): "batch_size": [1, 8, 13, 32, 67, 128], "cell_size": [128, 250, 512, 650, 1024, 1350], "time_steps": [40], - "use_gpu": [True, False] + "use_gpu": [True, False], + "dtype": ["float32", "float16"], }): + dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16 with ops.Graph().as_default(): with benchmarking.device(use_gpu=config["use_gpu"]): inputs = variable_scope.get_variable( "x", - [config["time_steps"], config["batch_size"], config["cell_size"]]) - cell = lstm_ops.LSTMBlockCell(config["cell_size"]) - outputs = rnn.dynamic_rnn( - cell, inputs, time_major=True, dtype=dtypes.float32) + dtype=dtype, + shape=[ + config["time_steps"], config["batch_size"], + config["cell_size"] + ]) + cell = lstm_ops.LSTMBlockCell(config["cell_size"], dtype=dtype) + outputs = rnn.dynamic_rnn(cell, inputs, time_major=True, dtype=dtype) init_op = variables.global_variables_initializer() with session.Session() as sess: @@ -464,12 +557,14 @@ def benchmarkLSTMBlockCellFpropWithDynamicRNN(self): # is set, this will produce a copy-paste-able CSV file. print(",".join( map(str, [ - config["batch_size"], config["cell_size"], config["cell_size"], - config["time_steps"], config["use_gpu"], wall_time + config["dtype"], config["batch_size"], config["cell_size"], + config["cell_size"], config["time_steps"], config["use_gpu"], + wall_time ]))) benchmark_name_template = "_".join([ - "LSTMBlockCell_fprop", "BS%(batch_size)i", "CS%(cell_size)i", - "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + "LSTMBlockCell_fprop", "DT_%(dtype)s", "BS%(batch_size)i", + "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i", + "gpu_%(use_gpu)s" ]) self.report_benchmark( @@ -488,8 +583,10 @@ def benchmarkLSTMBlockCellBpropWithDynamicRNN(self): "batch_size": [1, 8, 13, 32, 67, 128], "cell_size": [128, 250, 512, 650, 1024, 1350], "time_steps": [40], - "use_gpu": [True, False] + "use_gpu": [True, False], + "dtype": ["float32", "float16"], }): + dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16 with ops.Graph().as_default(): with benchmarking.device(use_gpu=config["use_gpu"]): time_steps = config["time_steps"] @@ -498,21 +595,21 @@ def benchmarkLSTMBlockCellBpropWithDynamicRNN(self): inputs = variable_scope.get_variable( "x", [time_steps, batch_size, cell_size], trainable=False, - dtype=dtypes.float32) + dtype=dtype) with variable_scope.variable_scope( "rnn", reuse=variable_scope.AUTO_REUSE): w = variable_scope.get_variable( "rnn/lstm_cell/kernel", shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) + dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[cell_size * 4], - dtype=dtypes.float32, + dtype=dtype, initializer=init_ops.zeros_initializer()) - cell = lstm_ops.LSTMBlockCell(cell_size) + cell = lstm_ops.LSTMBlockCell(cell_size, dtype=dtype) outputs = rnn.dynamic_rnn( - cell, inputs, time_major=True, dtype=dtypes.float32) + cell, inputs, time_major=True, dtype=dtype) grads = gradients_impl.gradients(outputs, [inputs, w, b]) init_op = variables.global_variables_initializer() @@ -524,12 +621,13 @@ def benchmarkLSTMBlockCellBpropWithDynamicRNN(self): # is set, this will produce a copy-paste-able CSV file. print(",".join( map(str, [ - batch_size, cell_size, cell_size, time_steps, config["use_gpu"], - wall_time + config["dtype"], batch_size, cell_size, cell_size, time_steps, + config["use_gpu"], wall_time ]))) benchmark_name_template = "_".join([ - "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i", - "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + "LSTMBlockCell_bprop", "DT_%(dtype)s", "BS%(batch_size)i", + "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i", + "gpu_%(use_gpu)s" ]) self.report_benchmark( diff --git a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py index b7393d8b988071..f90fd40990a32d 100644 --- a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py @@ -20,10 +20,13 @@ import abc +import six + from tensorflow.python.ops import array_ops from tensorflow.python.ops import rnn +@six.add_metaclass(abc.ABCMeta) class FusedRNNCell(object): """Abstract object representing a fused RNN cell. @@ -38,8 +41,6 @@ class FusedRNNCell(object): Every `FusedRNNCell` must implement `__call__` with the following signature. """ - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def __call__(self, inputs, diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 7edb0f110ca862..4db431f85a4673 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -19,6 +19,8 @@ import abc +import six + from tensorflow.contrib.rnn.ops import gen_lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes @@ -154,7 +156,7 @@ def _block_lstm(seq_len_max, Args: seq_len_max: A `Tensor` of type `int64`. - x: A list of at least 1 `Tensor` objects of the same type in: `float32`. + x: A list of at least 1 `Tensor` objects of the same type. w: A `Tensor`. Must have the same type as `x`. b: A `Tensor`. Must have the same type as `x`. cs_prev: A `Tensor`. Must have the same type as `x`. @@ -187,6 +189,7 @@ def _block_lstm(seq_len_max, Raises: ValueError: If `b` does not have a valid shape. """ + dtype = x[0].dtype batch_size = x[0].get_shape().with_rank(2).dims[0].value cell_size4 = b.get_shape().with_rank(1).dims[0].value if cell_size4 is None: @@ -195,13 +198,13 @@ def _block_lstm(seq_len_max, zero_state = None if cs_prev is None or h_prev is None: zero_state = array_ops.constant( - 0, dtype=dtypes.float32, shape=[batch_size, cell_size]) + 0, dtype=dtype, shape=[batch_size, cell_size]) if cs_prev is None: cs_prev = zero_state if h_prev is None: h_prev = zero_state if wci is None: - wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) + wci = array_ops.constant(0, dtype=dtype, shape=[cell_size]) wcf = wci wco = wci @@ -439,6 +442,7 @@ def call(self, inputs, state): return h, new_state +@six.add_metaclass(abc.ABCMeta) class LSTMBlockWrapper(base_layer.Layer): """This is a helper class that provides housekeeping for LSTM cells. diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 291ff83791c7cd..395a68c6446302 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -82,7 +82,6 @@ py_library( name = "keras_saved_model", srcs = ["python/saved_model/keras_saved_model.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], visibility = ["//visibility:public"], deps = [ "//tensorflow/python:array_ops", @@ -103,7 +102,14 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = [ + "no_windows", + # TODO(b/119022845): Re-enable this test in TAP. + "manual", + "notap", + "notsan", + "no_oss", + ], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 6aae4bc5e2981c..27b5b6d22e0fc1 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +import six from tensorflow.python.client import session from tensorflow.python.estimator import keras as estimator_keras_util @@ -30,6 +31,7 @@ from tensorflow.python.keras import models as models_lib from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.metrics import Metric from tensorflow.python.keras.models import model_from_json from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables @@ -276,11 +278,29 @@ def _create_signature_def_map(model, mode): inputs_dict.update(targets_dict) outputs_dict = {name: x for name, x in zip(model.output_names, model.outputs)} + metrics = estimator_keras_util._convert_keras_metrics_to_estimator(model) + + # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables + # are by default not added to any collections. We are doing this here, so + # that metric variables get initialized. + local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) + vars_to_add = set() + if metrics is not None: + for key, value in six.iteritems(metrics): + if isinstance(value, Metric): + vars_to_add.update(value.variables) + # Convert Metric instances to (value_tensor, update_op) tuple. + metrics[key] = (value.result(), value.updates[0]) + # Remove variables that are in the local variables collection already. + vars_to_add = vars_to_add.difference(local_vars) + for v in vars_to_add: + ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) + export_outputs = model_fn_lib.export_outputs_for_mode( mode, predictions=outputs_dict, loss=model.total_loss if model.optimizer else None, - metrics=estimator_keras_util._convert_keras_metrics_to_estimator(model)) + metrics=metrics) return export_helpers.build_all_signature_defs( inputs_dict, export_outputs=export_outputs, diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 364b65e06a3cdc..4970ebc31992c2 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -150,8 +150,6 @@ def test_saving_with_tf_optimizer(self): x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) - model.train_on_batch(x, y) - ref_y = model.predict(x) temp_saved_model = self._save_model_dir() @@ -308,6 +306,7 @@ def testSaveAndLoadSavedModelExport( self, model_builder, uses_learning_phase, optimizer, train_before_export): saved_model_path = self._save_model_dir() with self.session(graph=ops.Graph()): + np.random.seed(130) input_arr = np.random.random((1, 3)) target_arr = np.random.random((1, 3)) @@ -346,6 +345,11 @@ def testSaveAndLoadSavedModelExport( inputs, outputs = load_model(sess, output_path, model_fn_lib.ModeKeys.EVAL) + sess.run(outputs['metrics/mae/update_op'], { + inputs[input_name]: input_arr, + inputs[target_name]: target_arr + }) + eval_results = sess.run(outputs, {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) @@ -353,7 +357,7 @@ def testSaveAndLoadSavedModelExport( sess.run(training_module.get_global_step())) self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05) self.assertAllClose( - ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05) + ref_mae, eval_results['metrics/mae/value'], atol=1e-05) self.assertAllClose( ref_predict, eval_results['predictions/' + output_name], atol=1e-05) diff --git a/tensorflow/contrib/slim/python/slim/data/data_decoder.py b/tensorflow/contrib/slim/python/slim/data/data_decoder.py index 5a32be6c5a3290..46d33597e42912 100644 --- a/tensorflow/contrib/slim/python/slim/data/data_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/data_decoder.py @@ -39,12 +39,13 @@ def Decode(self, data, items): import abc +import six + +@six.add_metaclass(abc.ABCMeta) class DataDecoder(object): """An abstract class which is used to decode data for a provider.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def decode(self, data, items): """Decodes the data to returns the tensors specified by the list of items. diff --git a/tensorflow/contrib/slim/python/slim/data/data_provider.py b/tensorflow/contrib/slim/python/slim/data/data_provider.py index a49c0969d96bf7..3252b4fe8470f5 100644 --- a/tensorflow/contrib/slim/python/slim/data/data_provider.py +++ b/tensorflow/contrib/slim/python/slim/data/data_provider.py @@ -38,7 +38,10 @@ import abc +import six + +@six.add_metaclass(abc.ABCMeta) class DataProvider(object): """Maps a list of requested data items to tensors from a data source. @@ -46,7 +49,6 @@ class DataProvider(object): method which returns arbitrary types of data. No assumption is made about the source of the data nor the mechanism for providing it. """ - __metaclass__ = abc.ABCMeta def __init__(self, items_to_tensors, num_samples): """Constructs the Data Provider. diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index a6ce45c20365d9..1b2b6acacca838 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -25,6 +25,8 @@ import abc +import six + from tensorflow.contrib.slim.python.slim.data import data_decoder from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor @@ -37,6 +39,7 @@ from tensorflow.python.ops import sparse_ops +@six.add_metaclass(abc.ABCMeta) class ItemHandler(object): """Specifies the item-to-Features mapping for tf.parse_example. @@ -45,8 +48,6 @@ class ItemHandler(object): parsing. """ - __metaclass__ = abc.ABCMeta - def __init__(self, keys): """Constructs the handler with the name of the tf.Feature keys to use. diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index c96ca302d9e3f9..20bcd2447e6fd7 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -312,15 +312,20 @@ tf_cuda_cc_test( ], deps = [ ":trt_conversion", + "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:direct_session", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), @@ -340,6 +345,10 @@ tf_cuda_cc_test( ":trt_conversion", ":trt_plugins", "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 26f13b02a895b1..1f5591fe2a6025 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -81,12 +81,13 @@ std::vector GetLoadedTensorRTVersion() { return {ver_major, ver_minor, ver_patch}; } -namespace { +TrtCandidateSelector::TrtCandidateSelector( + const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} -bool IsTensorRTCandidate(const tensorflow::Node* node) { +Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { + // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange - // TODO(jie): Segmentation shouldn't associated with op name. - // Split it into a registration for each kernel. static const std::set candidate_ops = { "Identity", "Snapshot", @@ -127,13 +128,29 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { "Prod", "Max", "Min", - // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) - return (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + const bool is_supported_op_type = + (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + if (!is_supported_op_type) { + return errors::Unimplemented("Op type ", node->type_string(), + " is not supported."); + } + + std::vector input_edges; + TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); + std::vector> input_node_and_ports; + for (const Edge* input_edge : input_edges) { + input_node_and_ports.emplace_back(&input_edge->src()->def(), + input_edge->src_output()); + } + return validator_.ValidateNode(node->def(), input_node_and_ports, + graph_properties_); } +namespace { + tensorflow::Status BuildNodeMap( const tensorflow::Graph& graph, std::unordered_map* node_map) { @@ -846,9 +863,15 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; + TrtCandidateSelector candidate_selector(*params.graph_properties); TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties), - OutputEdgeValidator(), segment_options, &initial_segments)); + &graph, + std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector, + std::placeholders::_1), + // Input validation is already done by TrtCandidateSelector, so we don't + // need to check the input edges. + [](const Edge* edge) { return true; }, OutputEdgeValidator(), + segment_options, &initial_segments)); if (initial_segments.size() > 1) { VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << initial_segments.size(); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 3525202369841f..1c9d82105a7b38 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -31,6 +31,26 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// Helper class for the segmenter to determine whether given TF node is +// supported by TRT. +class TrtCandidateSelector { + public: + TrtCandidateSelector(const grappler::GraphProperties& graph_properties); + + // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added + // to TRT subgraph and later converted into TRT engine. + Status IsTensorRTCandidate(const tensorflow::Node* node); + + private: + // The TF-TRT node converter used to verify whether individual node is + // supported. It will operate in validation-only mode. + TrtNodeValidator validator_; + + // GraphProperties of the graph whose nodes are to be validated by + // IsTensorRTCandidate(). + const grappler::GraphProperties& graph_properties_; +}; + struct ConversionParams { ConversionParams() : input_graph_def(nullptr), diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc index 8146bed4b0541c..f10729987fdb78 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -15,9 +15,14 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include +#include +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" @@ -33,6 +38,76 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// TODO(laigd): put this into some test utils file. +void ExpectStatus(Status status, error::Code code = error::OK, + const char* substr = nullptr) { + EXPECT_EQ(code, status.code()) + << status << " vs expected error code \"" << error::Code_Name(code) + << "\" and message \"" << substr << "\""; + if (substr) { + EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status; + } +} + +TEST(TrtCandidateSelector, Basics) { + // Create a graph containing both TRT-compatible and TRT-incompatible nodes + // and use it to test TrtCandidateSelector::IsTensorRTCandidate(). + const std::vector input_shape_array{2, 2}; + TensorShape input_shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape)); + + Scope s = Scope::NewRootScope(); + ops::Placeholder::Attrs feed_attrs; + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_)); + + // Compatible input. + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs); + auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape); + + // Compatible MatMul. + auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1); + + // Incompatible MatMul. + ops::MatMul::Attrs matmul_attrs; + matmul_attrs.transpose_a_ = true; + auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"), + feed, const_1, matmul_attrs); + + // Unsupported op. + auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed); + + // Incompatible input. + auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE); + auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape); + // Compatible op with incompatible input. + auto matmul_with_incompatible_input = + ops::MatMul(s.WithOpName("matmul_with_incompatible_input"), + incompatible_feed, const_2); + + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + Tensor feed_tensor(DT_FLOAT, input_shape); + item.feed.push_back(std::make_pair("feed", feed_tensor)); + + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + TrtCandidateSelector selector(graph_properties); + TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); + ExpectStatus( + selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), + error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected " + "(op: MatMul), at: incompatible_matmul"); + ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), + error::UNIMPLEMENTED, "Op type Sin is not supported"); + ExpectStatus(selector.IsTensorRTCandidate( + matmul_with_incompatible_input.operation.node()), + error::INTERNAL, + "Failed to convert input with index 0 to a TRT_TensorOrWeights"); +} + class FakeCluster : public grappler::Cluster { public: FakeCluster() : Cluster(0) {} @@ -48,8 +123,7 @@ class FakeCluster : public grappler::Cluster { } Status Run(const GraphDef& graph_def, const std::vector>& feed, - const std::vector& fetch, - RunMetadata* metadata) override { + const std::vector& fetch, RunMetadata* metadata) override { return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index dcbc75aebf4e6e..a6f954391d3f2b 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -108,6 +108,18 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, return tensorflow::Status::OK(); } +template +inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, + bool ignore_first_dim) { + nvinfer1::Dims trt_dims; + const int offset = (ignore_first_dim ? 1 : 0); + for (int i = offset; i < shape.dims(); i++) { + trt_dims.d[i - offset] = shape.dim_size(i); + } + trt_dims.nbDims = shape.dims() - offset; + return trt_dims; +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -137,22 +149,37 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties, } } -tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, - const tensorflow::DataType dtype, - nvinfer1::DataType* trt_dtype) { - // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so - // put them there instead. +Status ValidateTensorProperties(const string& producer_node_type, + const tensorflow::DataType dtype, + const PartialTensorShape& shape, + bool validation_only, + nvinfer1::DataType* trt_dtype, + nvinfer1::Dims* trt_dims, int* batch_size) { + // Convert data type. TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + + // Convert shape. if (shape.dims() < 0) { - return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); + return errors::InvalidArgument("Input tensor rank is unknown."); + } + if (shape.dims() > nvinfer1::Dims::MAX_DIMS + 1) { // +1 for batch dim + return errors::OutOfRange("Input tensor rank is greater than ", + nvinfer1::Dims::MAX_DIMS + 1); } - if (shape.dims() > 9) { - return tensorflow::errors::OutOfRange( - "Input tensor rank is greater than 8."); + if (producer_node_type != "Const" && shape.dims() < 2) { + return errors::InvalidArgument( + "Input tensor with rank<2 is not supported since the first dimension " + "is treated as batch dimension by TRT"); } + *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true); + *batch_size = shape.dim_size(0); + + if (validation_only) return Status::OK(); + // Following are validations at runtime. + for (int d = 1; d < shape.dims(); ++d) { if (shape.dim_size(d) < 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Input tensor with shape ", shape.DebugString(), " has an unknown non-batch dimemension at dim ", d); } @@ -358,22 +385,75 @@ string TRT_ShapedWeights::DebugString() const { ", values=", reinterpret_cast(GetValues()), ")"); } +// A fake ITensor implementation used to check whether the TF-TRT converter can +// handle specific node. We only need shape and type information, and the +// converter won't (and shouldn't) use this to build the TRT network. +class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { + public: + SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims) + : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {} + + void setName(const char* name) override {} + + const char* getName() const override { return ""; } + + void setDimensions(nvinfer1::Dims dimensions) override { + trt_dims_ = dimensions; + } + + nvinfer1::Dims getDimensions() const override { return trt_dims_; } + + void setType(nvinfer1::DataType trt_dtype) override { + trt_dtype_ = trt_dtype; + } + + nvinfer1::DataType getType() const override { return trt_dtype_; } + + bool isNetworkInput() const override { return false; } + + bool isNetworkOutput() const override { return false; } + + void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {} + + bool getBroadcastAcrossBatch() const override { return false; } + + nvinfer1::TensorLocation getLocation() const override { + // This is arbitrary, since we don't use it. + return nvinfer1::TensorLocation::kDEVICE; + } + + void setLocation(nvinfer1::TensorLocation location) override {} + +#if NV_TENSORRT_MAJOR >= 5 + bool setDynamicRange(float min, float max) override {} +#endif + + private: + nvinfer1::DataType trt_dtype_; + nvinfer1::Dims trt_dims_; +}; + TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size) : tensor_(tensor), batch_size_(batch_size), - weights_(DT_FLOAT), initialized_(true), is_tensor_(true) {} -TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights) - : tensor_(nullptr), - weights_(weights), +TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, + int batch_size) + : simple_itensor_(new SimpleITensor(trt_dtype, trt_dims)), + batch_size_(batch_size), initialized_(true), - is_tensor_(false) {} + is_tensor_(true) {} + +TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights) + : weights_(weights), initialized_(true), is_tensor_(false) {} TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) : tensor_(rhs.tensor_), + simple_itensor_(rhs.simple_itensor_), batch_size_(rhs.batch_size_), weights_(rhs.weights_), initialized_(rhs.initialized_), @@ -381,12 +461,23 @@ TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) { tensor_ = rhs.tensor_; + simple_itensor_ = rhs.simple_itensor_; batch_size_ = rhs.batch_size_; weights_ = rhs.weights_; initialized_ = rhs.initialized_; is_tensor_ = rhs.is_tensor_; } +nvinfer1::ITensor* TRT_TensorOrWeights::tensor() { + CHECK(is_tensor()); + return tensor_ == nullptr ? simple_itensor_.get() : tensor_; +} + +const nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const { + CHECK(is_tensor()); + return tensor_ == nullptr ? simple_itensor_.get() : tensor_; +} + nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { if (is_tensor()) { return tensor()->getDimensions(); @@ -398,8 +489,8 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor @", reinterpret_cast(tensor_), - ", shape=", convert::DebugString(tensor_->getDimensions()), + StrAppend(&output, "tensor @", reinterpret_cast(tensor()), + ", shape=", convert::DebugString(tensor()->getDimensions()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -560,11 +651,10 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G const int c = iweights.shape_.d[2] / num_groups; const int k = iweights.shape_.d[3] * num_groups; - VLOG(2) << "num_groups: " << num_groups - << "c" << iweights.shape_.d[2] << " then " << c - << "k" << iweights.shape_.d[3] << " then " << k - << "r" << iweights.shape_.d[0] << " then " << r - << "s" << iweights.shape_.d[1] << " then " << s; + VLOG(2) << "num_groups: " << num_groups << "c" << iweights.shape_.d[2] + << " then " << c << "k" << iweights.shape_.d[3] << " then " << k + << "r" << iweights.shape_.d[0] << " then " << r << "s" + << iweights.shape_.d[1] << " then " << s; oweights->shape_.d[0] = k / num_groups; oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; @@ -608,9 +698,68 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(tensorflow::DataType type, TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); } +Status TrtNodeValidator::ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights) { + if (node_def.op() == "Const") { + if (output_port != 0) { + return errors::InvalidArgument("Const node should only have one output."); + } + // The output of the conversion will be used as input to other nodes to + // determine whether TRT supports those nodes. If it cannot convert the + // Const, it's very likely we cannot treat it as a tensor and make it an + // input to the TRT network, since TRT removes the first dimension and + // treats it as batch size. Also, it's not likely that the converter can + // support the op, and performance may suffer even if it can, so we just + // simply return error if the conversion fails. + std::vector inputs; + return ConvertConstToWeights(node_def, inputs, tensor_or_weights); + } + if (!graph_properties.HasOutputProperties(node_def.name())) { + return errors::InvalidArgument("Shape and data type are unknown"); + } + + // Validate and convert shape and dtype. + const auto& output_params = + graph_properties.GetOutputProperties(node_def.name()); + const auto& tensor_properties = output_params.at(output_port); + const DataType dtype = tensor_properties.dtype(); + const PartialTensorShape shape = tensor_properties.shape(); + nvinfer1::DataType trt_dtype; + nvinfer1::Dims trt_dims; + int batch_size = -1; + TF_RETURN_IF_ERROR(ValidateTensorProperties( + node_def.op(), dtype, shape, /*validation_only_=*/true, &trt_dtype, + &trt_dims, &batch_size)); + + // Adds a fake ITensor. This is fine since op converter operates in + // validation-only mode and it won't (and shouldn't) use the tensor to do + // any TRT network operations. + *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size); + return Status::OK(); +} + Status TrtNodeValidator::ValidateNode( const tensorflow::NodeDef& node_def, - const std::vector& inputs) { + const std::vector>& input_node_and_ports, + const grappler::GraphProperties& graph_properties) { + // Convert input NodeDef and corresponding output ports to + // TRT_TensorOrWeights. + std::vector inputs; + for (int i = 0; i < input_node_and_ports.size(); ++i) { + const auto& pair = input_node_and_ports[i]; + TRT_TensorOrWeights tensor_or_weights; + Status status = ConvertToTensorOrWeights( + *pair.first, pair.second, graph_properties, &tensor_or_weights); + if (!status.ok()) { + return errors::Internal("Failed to convert input with index ", i, + " to a TRT_TensorOrWeights"); + } + inputs.push_back(tensor_or_weights); + } + + // Validate the node. const auto iter = op_validators_.find(node_def.op()); if (iter == op_validators_.end()) { // If validator is not registered, it means no validation is needed. @@ -621,7 +770,19 @@ Status TrtNodeValidator::ValidateNode( OpConverterParams params( /*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr, /*arg_validation_only=*/true, &weight_store_); - Status status = validator(¶ms); + return validator(¶ms); +} + +Status TrtNodeValidator::ConvertConstToWeights( + const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output) { + std::vector outputs; + OpConverterParams params( + /*arg_converter=*/nullptr, const_node_def, inputs, &outputs, + /*arg_validation_only=*/true, &weight_store_); + Status status = op_validators_["Const"](¶ms); + if (status.ok() && output) *output = outputs[0]; return status; } @@ -1663,7 +1824,7 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { } tensorflow::Status ConvertScale(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) { @@ -1798,8 +1959,13 @@ Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, return Status::OK(); } +// Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter, it +// always ignores the params->validation_only parameter but adds the converted +// weights to params->outputs. We did this since TrtNodeValidator needs the +// weights as input to other nodes, and use it to determine whether those nodes +// are supported by TRT. tensorflow::Status ConvertConst(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (!inputs.empty()) { return errors::InvalidArgument( @@ -1896,11 +2062,10 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { return errors::Unimplemented("Not supported constant type, at ", node_def.name()); } - // Pass the output. - if (!params->validation_only) { + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } - return tensorflow::Status::OK(); + return Status::OK(); } tensorflow::Status ConvertIdentity(OpConverterParams* params) { @@ -1909,7 +2074,7 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) { } tensorflow::Status ConvertBinary(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { return tensorflow::errors::FailedPrecondition( @@ -2418,19 +2583,20 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { // TODO(jie): INT32 should be converted? tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { - return tensorflow::errors::Unimplemented( - "data type is not supported, for node " + node_def.name() + " got " + - tensorflow::DataTypeString(tf_dtype)); + return errors::Unimplemented("Data type is not supported, for node ", + node_def.name(), " got ", + DataTypeString(tf_dtype)); } bool transpose_a = attrs.get("transpose_a"); bool transpose_b = attrs.get("transpose_b"); // FullyConnected: if (transpose_a) { - return tensorflow::errors::Internal( - "Transpose_a is not supported for TensorRT FullyConnected (op: " + - node_def.op() + "), at: " + node_def.name()); + return errors::InvalidArgument( + "transpose_a is not supported for TensorRT FullyConnected (op: ", + node_def.op(), "), at: ", node_def.name()); } + if (params->validation_only) return Status::OK(); return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(), transpose_b, node_def.name()); } @@ -2673,10 +2839,13 @@ tensorflow::Status ConvertGraphDefToEngine( return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); } - nvinfer1::DataType dtype; + nvinfer1::DataType trt_dtype; + nvinfer1::Dims trt_dims; + int batch_size = -1; auto shape = input_shapes.at(slot_number); - auto status = ValidateInputProperties( - shape, node_def.attr().at("dtype").type(), &dtype); + auto status = ValidateTensorProperties( + node_def.op(), node_def.attr().at("dtype").type(), shape, + /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size); if (!status.ok()) { const string error_message = StrCat("Validation failed for ", node_name, " and input slot ", @@ -2684,19 +2853,13 @@ tensorflow::Status ConvertGraphDefToEngine( LOG(WARNING) << error_message; return Status(status.code(), error_message); } - - nvinfer1::Dims input_dim; - for (int i = 1; i < shape.dims(); i++) { - input_dim.d[i - 1] = shape.dim_size(i); - } - input_dim.nbDims = shape.dims() - 1; VLOG(2) << "Adding engine input tensor " << node_name << " with shape " - << DebugString(input_dim); + << DebugString(trt_dims); // TODO(laigd): the conversion should always happen at runtime where all // the shapes are known, and we can provide a mode to generate the // engines offline, by calling sess.run() and cache/serialize the engines. - TF_RETURN_IF_ERROR(converter.AddInputTensor(node_name, dtype, input_dim, - shape.dim_size(0))); + TF_RETURN_IF_ERROR( + converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && (node_def.op() == "Identity")) { int32 slot_number = -1; @@ -2866,34 +3029,6 @@ tensorflow::Status ConvertSegmentToGraphDef( return tensorflow::Status::OK(); } -bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { - if (in_edge->IsControlEdge()) return true; - PartialTensorShape shape; - tensorflow::DataType dtype; - GetOutputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), - &shape, &dtype); - nvinfer1::DataType trt_dtype; - Status status = ValidateInputProperties(shape, dtype, &trt_dtype); - if (!status.ok()) { - VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() - << ": " << status; - return false; - } - - - if (in_edge->src()->type_string() != "Const" && - // Single dimensional input tensor is not supported since the first - // dimension is treated as batch dimension. - shape.dims() < 2) { - VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() - << " which has an input at port " << in_edge->dst_input() << " with" - << " #dim<2" - << " and is not a const: " << shape; - return false; - } - return true; -} - bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { if (out_edge->IsControlEdge()) return true; if (out_edge->src()->type_string() == "Const") { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 699b50b37e3bb6..5cc28b33e7f2c5 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -148,21 +148,6 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType* engine, bool* convert_successfully); -// Helper class for the segmenter to determine whether an input edge to the TRT -// segment is valid. -class InputEdgeValidator { - public: - InputEdgeValidator(const grappler::GraphProperties& graph_properties) - : graph_properties_(graph_properties) {} - - // Return true if the specified edge is eligible to be an input edge of the - // TRT segment. - bool operator()(const tensorflow::Edge* in_edge) const; - - private: - const grappler::GraphProperties& graph_properties_; -}; - // Helper class for the segmenter to determine whether an output edge from the // TRT segment is valid. class OutputEdgeValidator { @@ -245,8 +230,21 @@ class TRT_TensorOrWeights { public: TRT_TensorOrWeights() {} + // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'. + // This is used by Converter when building the TRT network, where the ITensor + // is owned by the TRT network being built. See comment for 'tensor_' below. explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1); + // Constructor that makes it an ITensor by creating one using provided data + // type and shape, and takes ownership of the created ITensor. This is used by + // TrtNodeValidator to encapsulate the type and shape information for + // validation of graph nodes, and the created ITensor is fake and temporary, + // and should not be used to build any TRT network. See comment for + // 'simple_itensor_' below. + explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, int batch_size); + + // Constructor that makes it a TRT_TensorOrWeights. explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights); TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs); @@ -256,15 +254,9 @@ class TRT_TensorOrWeights { bool is_tensor() const { return initialized_ && is_tensor_; } bool is_weights() const { return initialized_ && !is_tensor_; } - nvinfer1::ITensor* tensor() { - CHECK(is_tensor()); - return tensor_; - } + nvinfer1::ITensor* tensor(); - const nvinfer1::ITensor* tensor() const { - CHECK(is_tensor()); - return tensor_; - } + const nvinfer1::ITensor* tensor() const; TRT_ShapedWeights& weights() { CHECK(is_weights()); @@ -283,9 +275,25 @@ class TRT_TensorOrWeights { string DebugString() const; private: + class SimpleITensor; + void set_batch_size(int batch_size) { batch_size_ = batch_size; } + // When it represents an ITensor, the ITensor can be either passed by the + // caller via the constructor that takes an ITensor* as parameter, or be + // created as a SimpleITensor. + // + // In the first case, the ITensor pointer is stored in 'tensor_' below, and + // the ITensor itself is not owned by this class. This method is used by + // Converter (e.g. AddInputTensor) and op converters during TRT network + // construction, where the TRT network owns the ITensor. + // + // In the second case, the created SimpleITensor is stored in + // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake + // implementation of ITensor and is used only by TrtNodeValidator to validate + // the graph nodes. nvinfer1::ITensor* tensor_ = nullptr; // Not owned. + std::shared_ptr simple_itensor_ = nullptr; // First dimension of the TF tensor (NOT tensor_) that is represented by // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s @@ -339,13 +347,35 @@ class TrtNodeValidator { public: TrtNodeValidator(); - // Validate the node, and return ok if it's supported by the converter. - Status ValidateNode(const NodeDef& node_def, - const std::vector& inputs); + // Validate the node, and return ok if it's supported by TRT. + // + // - 'node_def' is the node to validate. + // - 'input_node_and_ports' are the input NodeDefs and their output ports that + // are connected to 'node_def' in the TF graph. + // - 'graph_properties' is the GraphProperties of the graph where 'node_def' + // belongs. It is used to get the shape and data type information of a + // tensor for validation purpose. + Status ValidateNode( + const NodeDef& node_def, + const std::vector>& input_node_and_ports, + const grappler::GraphProperties& graph_properties); private: void RegisterOpValidators(); + // Convert a Const node to a TRT_TensorOrWeights. + Status ConvertConstToWeights(const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output); + + // Convert the output tensor at 'output_port' of 'node_def' to a + // TRT_TensorOrWeights which will be later used as an input to other nodes and + // passed to ValidateNode() below. + Status ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights); + // Stores all the validators by op type. If no validator is registered for // specific op, it means no validation is needed and ValidateNode() will // return OK. diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index bc390743335c53..c3a39395f3a99f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -29,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -45,6 +49,7 @@ namespace convert { using ::testing::ElementsAre; +// TODO(laigd): put this into some test utils file. void ExpectStatus(Status status, error::Code code = error::OK, const char* substr = nullptr) { EXPECT_EQ(code, status.code()) @@ -75,6 +80,23 @@ NodeDef MakeNodeDef(const string& name, const string& op, return node_def; } +template +NodeDef MakeConstNodeDef(const string& name, const std::vector& vals, + const TensorShape& shape) { + Scope s = Scope::NewRootScope(); + Tensor t = ::tensorflow::test::AsTensor(vals, shape); + auto const_op = ops::Const(s.WithOpName(name), t); + return const_op.node()->def(); +} + +template +NodeDef MakeConstNodeDef(const string& name, const std::vector& vals) { + TensorShape shape; + const std::vector shape_dims = {static_cast(vals.size())}; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape)); + return MakeConstNodeDef(name, vals, shape); +} + bool TrtDimsEquals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) { if (lhs.nbDims != rhs.nbDims) return false; for (int i = 0; i < lhs.nbDims; ++i) { @@ -95,6 +117,19 @@ bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, lhs.GetValues() == rhs.GetValues(); } +template +void ValidateWeights(const TRT_ShapedWeights& weights, + const std::vector& expected_dims, + const std::vector& expected_value) { + EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_)) + << weights.DebugString(); + ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString(); + const T* actual_values = static_cast(weights.GetValues()); + for (int i = 0; i < expected_value.size(); ++i) { + EXPECT_EQ(expected_value[i], actual_values[i]); + } +} + // Fake ITensor implementation for testing purposes. class FakeITensor : public nvinfer1::ITensor { public: @@ -194,32 +229,86 @@ TEST(TRT_ShapedWeights_Test, Basic) { } TEST(TRT_TensorOrWeights_Test, Basic) { + // Test constructor with no arguments. + { + TRT_TensorOrWeights tw; + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + for (auto ptr : {&tw, ©, &assigned}) { + EXPECT_EQ(false, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + EXPECT_EQ(-1, ptr->batch_size()); + } + } + + // Test constructor with ITensor and batch size argument. { nvinfer1::Dims dims; dims.nbDims = 1; dims.d[0] = 1; FakeITensor itensor(dims); - TRT_TensorOrWeights tw(&itensor); - EXPECT_EQ(true, tw.is_tensor()); - EXPECT_EQ(false, tw.is_weights()); - EXPECT_EQ(&itensor, tw.tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, tw.GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(tw.GetTrtDims()); + TRT_TensorOrWeights tw1(&itensor, /*batch_size=*/1); + + for (auto original_ptr : {&tw, &tw1}) { + TRT_TensorOrWeights copy(*original_ptr); + TRT_TensorOrWeights assigned; + assigned = *original_ptr; + + for (auto ptr : {original_ptr, ©, &assigned}) { + EXPECT_EQ(true, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + if (original_ptr == &tw) { + EXPECT_EQ(-1, ptr->batch_size()); + } else { + EXPECT_EQ(1, ptr->batch_size()); + } + EXPECT_EQ(&itensor, ptr->tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } + } } + // Test constructor which creates and owns an ITensor. { - TRT_ShapedWeights weights(DT_FLOAT); - TRT_TensorOrWeights tw(weights); - EXPECT_EQ(false, tw.is_tensor()); - EXPECT_EQ(true, tw.is_weights()); - EXPECT_TRUE(TrtShapedWeightsEquals(weights, tw.weights())); - nvinfer1::Dims dims; - dims.nbDims = 0; - EXPECT_TRUE(TrtDimsEqualsArray({}, tw.GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(tw.GetTrtDims()); + dims.nbDims = 1; + dims.d[0] = 1; + TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1); + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + + for (auto ptr : {&tw, ©, &assigned}) { + EXPECT_EQ(true, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + EXPECT_EQ(1, ptr->batch_size()); + EXPECT_NE(nullptr, ptr->tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } + } + // Test constructor with TRT_ShapedWeights argument. + { + TRT_ShapedWeights weights; + TRT_TensorOrWeights tw(weights); + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + for (auto ptr : {&tw, ©, &assigned}) { + EXPECT_EQ(false, ptr->is_tensor()); + EXPECT_EQ(true, ptr->is_weights()); + EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights())); + + nvinfer1::Dims dims; + dims.nbDims = 0; + EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } } } @@ -229,11 +318,64 @@ class ValidatorTest : public ::testing::Test { validator_.op_validators_[op_name] = op_validator; } + Status ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights) { + return validator_.ConvertToTensorOrWeights( + node_def, output_port, graph_properties, tensor_or_weights); + } + protected: TrtNodeValidator validator_; }; +TEST_F(ValidatorTest, ConvertToTensorOrWeights) { + // Convert Const. + { + NodeDef node_def = MakeConstNodeDef("my_const", {1.0f, 2.0f}); + TRT_TensorOrWeights output; + grappler::GrapplerItem item; + grappler::GraphProperties graph_properties(item); + ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, &output)); + ValidateWeights(output.weights(), {2}, {1.0, 2.0}); + } + // Convert non-Const. We test the case where the non-batch dimemsion is + // unknown as well, to make sure the validator allows that. + for (const int32 non_batch_dim : {-1, 2}) { + const int32 batch_size = 12; + + Scope s = Scope::NewRootScope(); + ops::Placeholder::Attrs attrs; + TF_EXPECT_OK(TensorShapeUtils::MakeShape( + std::vector{batch_size, non_batch_dim}, &attrs.shape_)); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs); + auto add = ops::Add(s.WithOpName("add"), feed, feed); + + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + auto& node_def = add.operation.node()->def(); + TRT_TensorOrWeights output; + ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, &output)); + EXPECT_EQ(true, output.is_tensor()); + EXPECT_EQ(batch_size, output.batch_size()); + EXPECT_NE(nullptr, output.tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims())) + << "- expected: {" << non_batch_dim << "} \n vs\n" + << "- actual: " << DebugString(output.GetTrtDims()); + } +} + TEST_F(ValidatorTest, ValidateNode) { + grappler::GrapplerItem item; + grappler::GraphProperties graph_properties(item); + bool start_conversion = false; bool should_fail = false; auto op_converter = [&start_conversion, @@ -245,16 +387,17 @@ TEST_F(ValidatorTest, ValidateNode) { NodeDef node_def = MakeNodeDef("my_op", "MyOp", {}); // Validator not registered, validation should pass. - TF_EXPECT_OK(validator_.ValidateNode(node_def, {})); + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); // Register validator. AddOpValidator("MyOp", op_converter); - TF_EXPECT_OK(validator_.ValidateNode(node_def, {})); + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); EXPECT_EQ(false, start_conversion); // Let the converter return error. should_fail = true; - ExpectStatus(validator_.ValidateNode(node_def, {}), error::INVALID_ARGUMENT); + ExpectStatus(validator_.ValidateNode(node_def, {}, graph_properties), + error::INVALID_ARGUMENT); } class ConverterTest : public ::testing::Test { @@ -289,6 +432,8 @@ class ConverterTest : public ::testing::Test { return converter_->GetInputs(node_def, inputs); } + int batch_size() const { return converter_->batch_size_; } + private: Logger logger_; // These members are ordered in a way such that the destruction order is: @@ -474,11 +619,48 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { << DebugString(*output_tensor); } +TEST_F(ConverterTest, MaybeUpdateBatchSize) { + EXPECT_EQ(-1, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(-1)); + EXPECT_EQ(-1, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + EXPECT_EQ(123, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + EXPECT_EQ(123, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(-1)); + EXPECT_EQ(123, batch_size()); + + ExpectStatus(MaybeUpdateBatchSize(124), error::INVALID_ARGUMENT, + "Provided batch size does not match converter batch size"); +} + +TEST_F(ConverterTest, AddAndGetTensorOrWeights) { + // Add a tensor. + FakeITensor fake_tensor; + TRT_TensorOrWeights tensor(&fake_tensor); + EXPECT_EQ(-1, tensor.batch_size()); + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + TF_EXPECT_OK(AddTensorOrWeights("my_tensor", tensor)); + + // Get the added tensor. + TRT_TensorOrWeights added_tensor; + TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &added_tensor)); + EXPECT_EQ(123, added_tensor.batch_size()); + + // Add the same tensor again. + ExpectStatus(AddTensorOrWeights("my_tensor", tensor), error::ALREADY_EXISTS, + "tensor/weights my_tensor already exist"); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { public: - OpConverterTest() { + OpConverterTest() : scope_(Scope::NewRootScope()) { QCHECK_EQ(0, cudaStreamCreate(&stream_)); Reset(); } @@ -505,8 +687,8 @@ class OpConverterTest : public ::testing::Test { converter_.reset(new Converter(network_.get(), /*fp16=*/false)); // Reset other related artifacts. - fake_itensors_.clear(); - fake_tensor_or_weights_.clear(); + scope_ = Scope::NewRootScope(); + validator_inputs_.clear(); } void BuildAndRun(const char* input_name, const std::vector& input_data, @@ -551,33 +733,41 @@ class OpConverterTest : public ::testing::Test { } // Add ITensor for both validation and conversion. - void AddTestTensor(const char* name, const std::vector& dims, - int batch_size = 1) { - nvinfer1::Dims trt_dims = GetTestDims(dims); - // Add FakeITensor for validation. - // - // TRT cannot add a tensor that has undetermined dims, so we manage the - // tensor using a vector. These tensors are used to test validation-only - // mode and thus should not be used to build the engine. - FakeITensor* fake_itensor = new FakeITensor(trt_dims); - fake_itensors_.emplace_back(fake_itensor); - fake_tensor_or_weights_[string(name)] = - TRT_TensorOrWeights{fake_itensor, batch_size}; + void AddTestTensor( + const char* name, const std::vector& dims, int batch_size = 1, + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { + DataType tf_dtype = DT_FLOAT; + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + tf_dtype = DT_FLOAT; + break; + case nvinfer1::DataType::kINT32: + tf_dtype = DT_INT32; + break; + default: + ASSERT_TRUE(false) << "Unexpected data type " + << static_cast(trt_dtype); + } + ops::Placeholder::Attrs attrs; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); + attrs.shape_.InsertDim(0, batch_size); + auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs); + validator_inputs_[name] = input.operation.node()->def(); // Add a real ITensor for conversion conditionally. + const nvinfer1::Dims trt_dims = GetTestDims(dims); if (HasStaticShape(trt_dims)) { - TF_EXPECT_OK(converter_->AddInputTensor(name, nvinfer1::DataType::kFLOAT, - trt_dims, batch_size)); + TF_EXPECT_OK( + converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size)); ASSERT_EQ(batch_size, converter_->batch_size_); } } // Add weights for both validation and conversion. - template - void AddTestWeights(const char* name, const DataType dtype, - const std::vector& dims, - const std::vector& values) { - QCHECK_EQ(DataTypeToEnum::v(), dtype); + template + void AddTestWeights(const char* name, const std::vector& dims, + const std::vector& values) { + const DataType dtype = DataTypeToEnum::v(); const nvinfer1::Dims trt_dims = GetTestDims(dims); const int64_t num_elements = TrtDimsNumElements(trt_dims); QCHECK_EQ(num_elements, values.size()) @@ -585,13 +775,15 @@ class OpConverterTest : public ::testing::Test { TRT_ShapedWeights weights(dtype); if (num_elements) { weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims); - QCHECK_EQ(weights.size_bytes(), sizeof(CType) * values.size()) - << weights.size_bytes() << " vs " << sizeof(CType) * values.size(); + QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size()) + << weights.size_bytes() << " vs " << sizeof(T) * values.size(); memcpy(const_cast(weights.GetValues()), values.data(), weights.size_bytes()); } // Add weights for validation. - fake_tensor_or_weights_[string(name)] = TRT_TensorOrWeights{weights}; + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape)); + validator_inputs_[name] = MakeConstNodeDef(name, values, shape); // Add weights for conversion. TF_EXPECT_OK( converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights})); @@ -601,12 +793,18 @@ class OpConverterTest : public ::testing::Test { void RunValidation(const NodeDef& node_def, error::Code expected_code = error::OK, const char* expected_msg_substr = nullptr) { - std::vector inputs; + std::vector> input_node_and_ports; for (const string& input : node_def.input()) { - inputs.emplace_back(fake_tensor_or_weights_[input]); + input_node_and_ports.emplace_back(&validator_inputs_[input], 0); } - ExpectStatus(validator_->ValidateNode(node_def, inputs), expected_code, - expected_msg_substr); + grappler::GrapplerItem item; + TF_EXPECT_OK(scope_.ToGraphDef(&item.graph)); + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + ExpectStatus(validator_->ValidateNode(node_def, input_node_and_ports, + graph_properties), + expected_code, expected_msg_substr); } void RunConversion(const NodeDef& node_def, @@ -637,8 +835,8 @@ class OpConverterTest : public ::testing::Test { TrtUniquePtrType network_; TrtUniquePtrType engine_; cudaStream_t stream_; - std::vector> fake_itensors_; - std::unordered_map fake_tensor_or_weights_; + Scope scope_; + std::unordered_map validator_inputs_; }; template @@ -662,15 +860,7 @@ void TestConvertConst(OpConverterTest* test) { test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output)); - EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, output.weights().shape_)) - << output.DebugString(); - ASSERT_EQ(expected_value.size(), output.weights().count()) - << output.DebugString(); - const OutputCType* actual_values = - static_cast(output.weights().GetValues()); - for (int i = 0; i < expected_value.size(); ++i) { - EXPECT_EQ(expected_value[i], actual_values[i]); - } + ValidateWeights(output.weights(), expected_dims, expected_value); }; auto& attr = *node_def.mutable_attr(); @@ -700,8 +890,6 @@ void TestConvertConst(OpConverterTest* test) { } } -// TODO(laigd): we should use c++ API to create the nodedef, so any change in -// the API will be captured. TEST_F(OpConverterTest, ConvertConst) { { Reset(); @@ -713,10 +901,9 @@ TEST_F(OpConverterTest, ConvertConst) { } { Reset(); - NodeDef node_def = MakeNodeDef("my_const", "Const", {}); - (*node_def.mutable_attr())["dtype"].set_type(DT_DOUBLE); + NodeDef node_def = MakeConstNodeDef("my_const", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Unsupported data type"); + "Unsupported data type double"); } TestConvertConst(this); @@ -732,8 +919,14 @@ TEST_F(OpConverterTest, ConvertTranspose) { node_def, error::INVALID_ARGUMENT, "Input expects tensor and weights, at my_transpose"); } - NodeDef node_def = - MakeNodeDef("my_transpose", "Transpose", {"input", "weights"}); + + // Get the NodeDef for Transpose. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); + const NodeDef& node_def = transpose.operation.node()->def(); + { // Permutation is a tensor, should fail. Reset(); @@ -747,7 +940,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Transpose at batch dimension, should fail. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {4}, {1, 0, 2, 3}); + AddTestWeights("weights", {4}, {1, 0, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Transpose at batch dimension is not supported"); } @@ -755,7 +948,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Permutation rank doesn't match, should fail. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {3}, {0, 1, 2}); + AddTestWeights("weights", {3}, {0, 1, 2}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Rank of perm for transpose does not match with that of the input."); @@ -764,7 +957,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Ok. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {4}, {0, 3, 1, 2}); + AddTestWeights("weights", {4}, {0, 3, 1, 2}); RunConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); @@ -786,7 +979,14 @@ TEST_F(OpConverterTest, ConvertReshape) { node_def, error::INVALID_ARGUMENT, "Input expects weights for shape, at my_reshape"); } - NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {"input", "weights"}); + + // Get the NodeDef for Reshape. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights); + const NodeDef& node_def = reshape.operation.node()->def(); + { // Shape is a tensor, should fail. Reset(); @@ -800,7 +1000,7 @@ TEST_F(OpConverterTest, ConvertReshape) { // Reshape to scalar, should fail. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {}, {}); + AddTestWeights("weights", {0}, {}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Reshape to shape=[] is not supported, at my_reshape"); @@ -830,7 +1030,7 @@ TEST_F(OpConverterTest, ConvertReshape) { Reset(); const std::vector& dims = params[i].tensor_dims; AddTestTensor("input", dims, params[i].batch_size); - AddTestWeights("weights", DT_INT32, {4}, params[i].shape); + AddTestWeights("weights", {4}, params[i].shape); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Reshape on batch dimension is not supported, at my_reshape", @@ -847,7 +1047,7 @@ TEST_F(OpConverterTest, ConvertReshape) { for (int i = 0; i < kReshapeOKCases; ++i) { Reset(); AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); - AddTestWeights("weights", DT_INT32, {4}, ok_params[i].shape); + AddTestWeights("weights", {4}, ok_params[i].shape); RunConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); @@ -869,24 +1069,67 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::INVALID_ARGUMENT, "Input expects tensor and weights, at my_matmul"); } - NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {"input", "weights"}); - auto& attr = *node_def.mutable_attr(); - attr["T"].set_type(DT_FLOAT); - attr["transpose_a"].set_b(false); - attr["transpose_b"].set_b(false); - { - AddTestTensor("input", {2}, 1); - AddTestWeights("weights", DT_FLOAT, {2, 1}, {3, 5}); - RunConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); - EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, output.tensor()->getDimensions())) - << output.DebugString(); - std::vector output_data(1); - BuildAndRun("input", {2, 7}, "my_matmul", &output_data); - EXPECT_THAT(output_data, ElementsAre(41)); + // Get the NodeDef for Reshape. + auto get_matmul_nodedef = [](DataType dtype, bool transpose_a, + bool transpose_b) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + ops::MatMul::Attrs matmul_attrs; + matmul_attrs.transpose_a_ = transpose_a; + matmul_attrs.transpose_b_ = transpose_b; + auto matmul = + ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs); + return matmul.operation.node()->def(); + }; + + { + // Unsupported data type. + Reset(); + NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false); + AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); + AddTestWeights("weights", {2, 1}, {3, 5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Data type is not supported, for node my_matmul got int32"); + } + { + // transpose_a is set. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected"); + } + } + { + // OK. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + EXPECT_TRUE(output.is_tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions())) + << output.DebugString(); + + std::vector output_data(2); + BuildAndRun("input", {0, 1}, "my_matmul", &output_data); + if (transpose_b) { + EXPECT_THAT(output_data, ElementsAre(1, 3)); + } else { + EXPECT_THAT(output_data, ElementsAre(2, 3)); + } + } } } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index b3fb011f16221b..b30d94b0282451 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -163,7 +163,6 @@ void TRTOptimizationPass::PrintDebugInfo( } else { LOG(INFO) << offset << "No keep ops"; } - LOG(INFO) << item.graph.DebugString(); for (const auto dev : cluster->GetDeviceSet()->devices()) { const auto& pname = dev->parsed_name(); LOG(INFO) << "Device name= " << dev->name() diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 99890d910e717d..bb81fbf93f37b9 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -109,7 +109,11 @@ def tensorrt_rewriter_config(rewriter_config=None, if rewriter_config is None: rewriter_config = rewriter_config_pb2.RewriterConfig() - rewriter_config.optimizers.extend(["constfold", "layout"]) + # Layout optimizer may add Const nodes followed by Reshape nodes, thus we + # need to run constant folding again. + rewriter_config.optimizers.extend(["constfold", "layout", "constfold"]) + rewriter_config.meta_optimizer_iterations = ( + rewriter_config_pb2.RewriterConfig.ONE) if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): raise ValueError(("precision mode '{}' is not supported." diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py index 530adafcb3fe67..9f2eeac990dcac 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py @@ -26,6 +26,7 @@ # pylint: enable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_util from tensorflow.python.framework import importer @@ -57,7 +58,10 @@ def testTensorrtRewriterConfig(self): is_dynamic_op=True, maximum_cached_engines=2, cached_engine_batch_sizes=[1, 128]) - self.assertEqual(["constfold", "layout"], rewriter_cfg.optimizers) + self.assertEqual(["constfold", "layout", "constfold"], + rewriter_cfg.optimizers) + self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, + rewriter_cfg.meta_optimizer_iterations) trt_optimizer = None for optimizer in rewriter_cfg.custom_optimizers: if optimizer.name == "TensorRTOptimizer": diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index c82d4a018392be..4f64b7a9522a17 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -389,7 +389,7 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, + const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { @@ -409,9 +409,16 @@ tensorflow::Status SegmentGraph( std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); - if (options.exclude_node_list.count(node->name()) != 0 || - !candidate_fn(node->tf_node())) { + if (options.exclude_node_list.count(node->name()) != 0) { + VLOG(1) << "Not a TF-TRT candidate: " << node->name() + << " (excluded by segmenter option)."; node = nullptr; + } else { + const Status status = candidate_fn(node->tf_node()); + if (!status.ok()) { + VLOG(1) << "Not a TF-TRT candidate: " << node->name() << ": " << status; + node = nullptr; + } } node_segments.emplace_back(node); } diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 8c44eb782aa370..b9693aad1b7645 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -43,7 +43,7 @@ struct SegmentOptions { // Get the subgraphs of a graph that can be handled by TensorRT. // // @param graph tensorflow::Graph of the network -// @param candidate_fn A function that returns true for a Node* if +// @param candidate_fn A function that returns OK for a Node* if // that node can be handled by TensorRT. // @param segments Returns the TensorRT segments/subgraphs. Each entry // in the vector describes a subgraph by giving a set of the names of @@ -51,7 +51,7 @@ struct SegmentOptions { // @return the status. tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, + const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 5937fa8259a393..4805ef9c61a778 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -34,10 +34,13 @@ namespace ops = ::tensorflow::ops; class SegmentTest : public ::testing::Test { protected: - std::function MakeCandidateFn( + std::function MakeCandidateFn( const std::set& node_names) { - return [node_names](const tensorflow::Node* node) -> bool { - return node_names.find(node->name()) != node_names.end(); + return [node_names](const tensorflow::Node* node) -> Status { + if (node_names.find(node->name()) != node_names.end()) { + return Status::OK(); + } + return errors::NotFound(""); }; } diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 7d006b73d53631..7545bb9df20f29 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -118,30 +118,11 @@ def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" return super(BiasaddMatMulTest, self).GetConversionParams(run_params)._replace( - max_batch_size=4, maximum_cached_engines=2) - - def _ValidEngines(self): - """Engines expected to build and run.""" - return ["my_trt_op_0"] - - def _InvalidEngines(self): - """Engines that will cause conversion error at building time.""" - return ["my_trt_op_1", "my_trt_op_2"] + max_batch_size=4, maximum_cached_engines=1) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # In dynamic engine mode the engines are built in execution time, not in - # conversion time, so build errors occurs later. Here three of the engines - # will be failed to built but the corresponding engine op are still created. - # TODO(aaroey, jjsjann123): fix this. - if (run_params.dynamic_engine and - not trt_test.IsQuantizationMode(run_params.precision_mode)): - return self._ValidEngines() + self._InvalidEngines() - return self._ValidEngines() - - def ExpectedEnginesToRun(self, run_params): - """Return the expected engines to run.""" - return self._ValidEngines() + return ["my_trt_op_0"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py index 3cf7dadb1f4722..bbc724ab18e18b 100644 --- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py +++ b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py @@ -79,7 +79,7 @@ def GetParams(self): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_3": ["reshape-%d" % i for i in range(7)] + + "my_trt_op_0": ["reshape-%d" % i for i in range(7)] + ["reshape-%d/shape" % i for i in range(7)] } diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py index 832d34d60d0553..49260f272eeb27 100644 --- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py +++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py @@ -339,7 +339,7 @@ def test_filter_input_filter_vocab(self): lookup.KeyValueTensorInitializer(keys, values), -1) with self.cached_session(): - vocab_freq_table.init.run() + vocab_freq_table.initializer.run() # No vocab_freq_table specified - output should be the same as input. no_table_output = skip_gram_ops._filter_input( @@ -396,7 +396,7 @@ def test_filter_input_subsample_vocab(self): lookup.KeyValueTensorInitializer(keys, values), -1) with self.cached_session(): - vocab_freq_table.init.run() + vocab_freq_table.initializer.run() output = skip_gram_ops._filter_input( input_tensor=input_tensor, vocab_freq_table=vocab_freq_table, diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index 4edfbe58a1973f..edd97b2a4c131d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -21,6 +21,8 @@ import abc import collections +import six + from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures @@ -53,11 +55,10 @@ ]) +@six.add_metaclass(abc.ABCMeta) class TimeSeriesModel(object): """Base class for creating generative time series models.""" - __metaclass__ = abc.ABCMeta - def __init__(self, num_features, exogenous_feature_columns=None, diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py index 7fa538a16ecd7d..e9e2ac0aaf4c4d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py @@ -20,6 +20,8 @@ import abc +import six + from tensorflow.contrib import distributions from tensorflow.contrib.timeseries.python.timeseries import math_utils @@ -32,11 +34,10 @@ from tensorflow.python.util import nest +@six.add_metaclass(abc.ABCMeta) class FilteringStepPostprocessor(object): """Base class for processors that are applied after each filter step.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def process_filtering_step(self, current_times, current_values, predicted_state, filtered_state, outputs): diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index c2e3be03db0e4c..aae1ab1d37a166 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -154,6 +154,14 @@ message OptimizationParameters { // updates; not present means no limits are applied. ClippingLimits gradient_clipping_limits = 7; + // Amount of weight decay to apply; see weight_decay_optimizers.py for + // details. Almost all optimizers are supported with this option (MDL Adagrad + // Light does not work, and SGD does not behave as expected if it is enabled). + // Although there is no check, users who want weight decay will probably also + // want to enable gradient accumulation as well so that the decay will happen + // once per minibatch. + float weight_decay_factor = 16; + // Whether to use gradient accumulation (do two passes over the input // gradients: one to accumulate them into a temporary array and another to // apply them using the actual optimization algorithm). This feature is diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index 78253d83fc4dcc..c32bd5997c1493 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -102,7 +102,8 @@ def _write_graph_fn(self): training_util.write_graph( ops.get_default_graph().as_graph_def(add_shapes=True), self._checkpoint_dir, "graph.pbtxt") - self._write_graph_thread = threading.Thread(target=_write_graph_fn) + self._write_graph_thread = threading.Thread(target=_write_graph_fn, + args=[self]) self._write_graph_thread.start() saver_def = self._get_saver().saver_def if self._get_saver() else None diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 9f76435bebc990..ce2c322ff49382 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -543,6 +543,7 @@ def make_feed_dict(self, tpu_model_op): pass +@six.add_metaclass(abc.ABCMeta) class TPUInfeedManager(object): """TPUInfeedManager manages the data infeeding of data to a TPU computation. @@ -977,7 +978,7 @@ def _model_fn(): # When running on more than one core, concatenate outputs at the end # of processing. In backprop stage, the gradients will be - # calculdated according to the local inputs as gradient of + # calculated according to the local inputs as gradient of # cross-replica-concat being zero for any outputs other than those # from mlocal core so the loss calculation is identical. num_towers = self.model._tpu_assignment.num_towers @@ -1004,7 +1005,9 @@ def _model_fn(): for tensor in tpu_targets ] - if is_training or is_test: + if is_training or is_test: + with variable_scope.variable_scope( + 'metrics', reuse=variable_scope.AUTO_REUSE): self._cloned_model.compile( optimizer=_replicated_optimizer(self._cloned_optimizer), loss=self.model.loss, @@ -1023,29 +1026,29 @@ def _model_fn(): # the Momentum optimizer) when _make_train_function is invoked. with keras_tpu_variables.replicated_variable_for_optimizer( self._tpu_assignment.num_towers): - self._cloned_model._make_train_function() + self._cloned_model._make_fit_function() else: - self._cloned_model._make_train_function() + self._cloned_model._make_fit_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model.train_function.outputs + for tensor in self._cloned_model._fit_function.outputs ] return [ - self._cloned_model.train_function.updates_op, + self._cloned_model._fit_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - self._cloned_model.train_function.outputs, + self._cloned_model._fit_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - self._cloned_model._make_test_function() + self._cloned_model._make_eval_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model.test_function.outputs + for tensor in self._cloned_model._eval_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - self._cloned_model.test_function.outputs, + self._cloned_model._eval_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: @@ -1071,7 +1074,7 @@ def _model_fn(): # `execute op` replicates `_model_fn` `num_replicas` times, with each shard # running on a different logical core. compile_op, execute_op = tpu.split_compile_and_replicate( - _model_fn, inputs=[[]] * self._tpu_assignment.num_towers) + _model_fn, inputs=[[] for _ in range(self._tpu_assignment.num_towers)]) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. @@ -1216,7 +1219,7 @@ def _process_outputs(self, outfeed_outputs): """ # TODO(xiejw): Decide how to reduce outputs, or discard all but first. if self.execution_mode == model_fn_lib.ModeKeys.PREDICT: - outputs = [[]] * len(self._outfeed_spec) + outputs = [[] for _ in range(len(self._outfeed_spec))] outputs_per_replica = len(self._outfeed_spec) for i in range(self._tpu_assignment.num_towers): @@ -1375,6 +1378,8 @@ def __init__(self, cpu_model, strategy): self.predict_function = None self.test_function = None self.train_function = None + self._fit_function = None + self._eval_function = None cluster_resolver = strategy._tpu_cluster_resolver self._tpu_name_or_address = cluster_resolver.get_master() @@ -1536,10 +1541,17 @@ def evaluate(self, verbose=1, sample_weight=None, steps=None): - assert not self._numpy_to_infeed_manager_list # Ensure empty. + original_numpy_to_infeed_manager_list = [] + if self._numpy_to_infeed_manager_list: + # evaluate call may be executed as callbacks during the training. In this + # case, _numpy_to_infeed_manager_list is not empty, so save it for + # recovery at the end of evaluate call. + original_numpy_to_infeed_manager_list = self._numpy_to_infeed_manager_list + self._numpy_to_infeed_manager_list = [] with _tpu_session_context(): - infeed_managers = [] # Managers to clean up at the end of the fit call. + # Managers to clean up at the end of the evaluate call. + infeed_managers = [] if isinstance(x, dataset_ops.Dataset): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( @@ -1569,7 +1581,8 @@ def evaluate(self, return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose, sample_weight, steps) finally: - self._numpy_to_infeed_manager_list = [] + self._numpy_to_infeed_manager_list = ( + original_numpy_to_infeed_manager_list) def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, @@ -1985,10 +1998,21 @@ def optimizer(self): def optimizer(self, optimizer): self._optimizer = optimizer + @property + def stateful_metric_functions(self): + if self._tpu_model: + return self._tpu_model.stateful_metric_functions + return self._stateful_metric_functions + + @stateful_metric_functions.setter + def stateful_metric_functions(self, stateful_metric_functions): + self._stateful_metric_functions = stateful_metric_functions + def _make_train_function(self): if not self.train_function: self.train_function = TPUFunction( - self, model_fn_lib.ModeKeys.TRAIN, + self, + model_fn_lib.ModeKeys.TRAIN, tpu_assignment=self._tpu_assignment) return self.train_function @@ -1999,6 +2023,21 @@ def _make_test_function(self): self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) return self.test_function + def _make_fit_function(self): + if not self._fit_function: + self._fit_function = TPUFunction( + self, + model_fn_lib.ModeKeys.TRAIN, + tpu_assignment=self._tpu_assignment) + + return self._fit_function + + def _make_eval_function(self): + if not self._eval_function: + self._eval_function = TPUFunction( + self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) + return self._eval_function + def _make_predict_function(self): if not self.predict_function: self.predict_function = TPUFunction( diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index b6bb5c6e56c740..6ae718cc2c9716 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -189,12 +189,13 @@ def tpu_device_ordinal_at_coordinates(self, device_coordinates): def cpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the CPU device attached to a logical core.""" return _tpu_host_device_name( - job, self._topology_tasks[device_coordinates]) + job, self._topology_tasks[tuple(device_coordinates)]) def tpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, self._topology_tasks[device_coordinates], - self._topology_devices[device_coordinates]) + return _tpu_device_name(job, + self._topology_tasks[tuple(device_coordinates)], + self._topology_devices[tuple(device_coordinates)]) @property def num_tasks(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index a51dcc8020a423..e3e791faacb9b3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -372,14 +372,11 @@ def AddOp(self, op): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index 65ac9a6224e594..3fe896426a7ae5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -319,7 +319,7 @@ def __init__(self, mode: `TRAINING` or `INFERENCE`. optimization_parameters: `AdagradParameters`, `AdamParameters`, `Stochasticgradientdescentparameters`. Must be set in training and must - not be `None` in inference. + be `None` in inference. tpu_embedding_test: A `bool`. Only used for testing. Raises: diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 8b0b618d44de18..555ad0f1fdbe36 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2998,6 +2998,12 @@ def __init__(self, message): control_flow_ops.ControlFlowContext.__init__(self) self._message = message + def to_control_flow_context_def(self, context_def, export_scope=None): + # pylint: disable=useless-super-delegation + # NOTE(slebedev): the method is required by `ControlFlowContext`. + super(_CapturingContext, self).to_control_flow_context_def( + context_def, export_scope) + def AddOp(self, op): # pylint: disable=invalid-name for c in op.inputs: if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access diff --git a/tensorflow/contrib/training/python/training/tuner.py b/tensorflow/contrib/training/python/training/tuner.py index 8843632619f088..ad647a61da7adb 100644 --- a/tensorflow/contrib/training/python/training/tuner.py +++ b/tensorflow/contrib/training/python/training/tuner.py @@ -21,9 +21,12 @@ import abc +import six + from tensorflow.contrib.framework.python.framework import experimental +@six.add_metaclass(abc.ABCMeta) class Tuner(object): """Tuner class is the interface for Experiment hyper-parameters tuning. @@ -42,8 +45,6 @@ def _create_my_experiment(run_config, hparams): learn_runner.tune(experiment_fn=_create_my_experiment, tuner) """ - __metaclass__ = abc.ABCMeta - @experimental @abc.abstractmethod def next_trial(self): diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py index 338acef63f2446..acc5a049aa8764 100644 --- a/tensorflow/contrib/util/__init__.py +++ b/tensorflow/contrib/util/__init__.py @@ -15,8 +15,6 @@ """Utilities for dealing with Tensors. -See [Contrib Util](https://tensorflow.org/api_guides/python/contrib.util) guide. - @@constant_value @@make_tensor_proto @@make_ndarray diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4628258efc4a7f..afe4c46c8efc59 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -131,7 +131,6 @@ load( "tf_kernel_tests_linkstatic", "tf_lib_proto_compiler_deps", "tf_lib_proto_parsing_deps", - "tf_nano_proto_library", "tf_platform_hdrs", "tf_platform_srcs", "tf_proto_library", @@ -252,15 +251,6 @@ tf_jspb_proto_library( deps = [":protos_all_cc"], ) -tf_nano_proto_library( - name = "protos_all_nano_proto", - field_style = "accessors", - generate_equals = 1, - generate_intdefs = 1, - visibility = ["//visibility:public"], - deps = [":protos_all_cc"], -) - proto_library( name = "example_protos", srcs = [ @@ -2405,7 +2395,6 @@ cc_library( ]), hdrs = [ "lib/bfloat16/bfloat16.h", - "lib/core/casts.h", "lib/core/stringpiece.h", "lib/png/png_io.h", "platform/byte_order.h", @@ -2525,6 +2514,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ }) FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ + "framework/model.h", # only needed for tests "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests "framework/resource_var.h", @@ -3621,6 +3611,7 @@ tf_cc_tests( "framework/kernel_def_builder_test.cc", "framework/kernel_def_util_test.cc", "framework/memory_types_test.cc", + "framework/model_test.cc", "framework/node_def_builder_test.cc", "framework/node_def_util_test.cc", "framework/op_compatibility_test.cc", @@ -4378,6 +4369,7 @@ tf_cc_test( "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:partitioned_function_ops", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:shape_ops", "//third_party/eigen3", diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 06b797e32edc04..f610facd75d8cf 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -17,6 +17,10 @@ load( "tf_cc_binary", "tf_cc_test", ) +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) filegroup( name = "base_api_def", @@ -40,6 +44,7 @@ cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], hdrs = ["excluded_ops.h"], + copts = if_mkl(["-DINTEL_MKL=1"]), ) cc_library( diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 51812caeb29792..6f988569159536 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -176,6 +176,19 @@ void TestDeprecatedAttributesSetCorrectly( } } } + +void TestDeprecationVersionSetCorrectly( + const std::unordered_map& api_defs_map) { + for (const auto& name_and_api_def : api_defs_map) { + const auto& name = name_and_api_def.first; + const auto& api_def = name_and_api_def.second; + ASSERT_TRUE(api_def.deprecation_version() == 0 || + api_def.deprecation_message().empty()) + << "ApiDef that includes deprecation_version > 0 must also specify " + << "a deprecation_message. Op " << name + << " has deprecation_version > 0 but deprecation_message is not set."; + } +} } // namespace class BaseApiTest : public ::testing::Test { @@ -268,6 +281,12 @@ TEST_F(BaseApiTest, DeprecationSetCorrectly) { TestDeprecatedAttributesSetCorrectly(api_defs_map_); } +// Checks that deprecation_version is set for entire op only if +// deprecation_message is set. +TEST_F(BaseApiTest, DeprecationVersionSetCorrectly) { + TestDeprecationVersionSetCorrectly(api_defs_map_); +} + class PythonApiTest : public ::testing::Test { protected: PythonApiTest() { @@ -309,4 +328,10 @@ TEST_F(PythonApiTest, DeprecationSetCorrectly) { TestDeprecatedAttributesSetCorrectly(api_defs_map_); } +// Checks that deprecation_version is set for entire op only if +// deprecation_message is set. +TEST_F(PythonApiTest, DeprecationVersionSetCorrectly) { + TestDeprecationVersionSetCorrectly(api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalNonSerializableDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalNonSerializableDataset.pbtxt new file mode 100644 index 00000000000000..08632aa262a35b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalNonSerializableDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalNonSerializableDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc index 931c943dbc803c..02026e94abc5b3 100644 --- a/tensorflow/core/api_def/excluded_ops.cc +++ b/tensorflow/core/api_def/excluded_ops.cc @@ -21,7 +21,19 @@ const std::unordered_set* GetExcludedOps() { static std::unordered_set* excluded_ops = new std::unordered_set( {"BigQueryReader", "GenerateBigQueryReaderPartitions", - "GcsConfigureBlockCache", "GcsConfigureCredentials"}); + "GcsConfigureBlockCache", "GcsConfigureCredentials", +#ifdef INTEL_MKL + // QuantizedFusedOps for Intel CPU + "QuantizedConv2DAndRequantize", "QuantizedConv2DWithBias", + "QuantizedConv2DWithBiasAndRequantize", "QuantizedConv2DAndRelu", + "QuantizedConv2DAndReluAndRequantize", + "QuantizedConv2DWithBiasAndRelu", + "QuantizedConv2DWithBiasAndReluAndRequantize", + "QuantizedConv2DWithBiasSumAndRelu", + "QuantizedConv2DWithBiasSumAndReluAndRequantize", + "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" +#endif // INTEL_MKL + }); return excluded_ops; } } // namespace tensorflow diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt index e395e333bf5104..801dfbc28545da 100644 --- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "manip.batch_to_space_nd" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt index 7ad7cbcba9a906..1c90c56f5e73f9 100644 --- a/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "betainc" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt index f2265bad56cd8c..331bb9cbf5581c 100644 --- a/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "ceil" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt index 541b09a591fcdd..33110d8c9ec3ff 100644 --- a/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "check_numerics" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt index 942f4e6ed8da2b..5db2667262686f 100644 --- a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "cholesky" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt index e8a871cae6b101..51394dda4e9797 100644 --- a/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "cross" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt index 8b96eee6311e4a..e4a61e122ceddb 100644 --- a/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "decode_base64" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt index 829608fc8f9ae9..a85a76a8dc6669 100644 --- a/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "decode_compressed" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt index 9f28bc5f59bdc1..13ffbcce7c71cc 100644 --- a/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "decode_json_example" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt index 0010a59ca40adb..dab7a5e0094721 100644 --- a/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "decode_raw" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt index 5edd0c216ba4ed..96844a65b510cb 100644 --- a/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "dequantize" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt index cba30e63e892cf..43e7af891c510c 100644 --- a/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "diag" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt index 54e1f34e82b3c5..6a149848f69e63 100644 --- a/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "diag_part" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt index 91b4dfead77664..e6e9375ecd954b 100644 --- a/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "digamma" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt index 71bb73cfb24ee8..534b5d8152c149 100644 --- a/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "encode_base64" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt index e96df0c596ab19..fccda9dfca5967 100644 --- a/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "erfc" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt index 8ddf9d4d70f491..d8bdaeadc88ca0 100644 --- a/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "expm1" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt index f008b1222deeca..0bd8b1c11aa15b 100644 --- a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "extract_image_patches" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt index d79e936b7195dd..7f4a2add4e7131 100644 --- a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fft" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt index d8db83331f916c..97ab3ff7efc1f4 100644 --- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fake_quant_with_min_max_args" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt index 74f01d1a0c5691..a30bdc35343410 100644 --- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fake_quant_with_min_max_args_gradient" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt index e14fb6d118ada9..fc64d0e15acea3 100644 --- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fake_quant_with_min_max_vars" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt index 4611ebdfb82860..66fcfbb8466675 100644 --- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fake_quant_with_min_max_vars_gradient" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt index 0936e513c3ff09..132ecc1ac49f57 100644 --- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fake_quant_with_min_max_vars_per_channel" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt index 0d9968248c5397..66c811b6a26ca1 100644 --- a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "fake_quant_with_min_max_vars_per_channel_gradient" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt index 598f23bde3c3ca..1c3b9d5571e325 100644 --- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "manip.gather_nd" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt index 17fbd8ace4333f..0124721e1cb185 100644 --- a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "ifft" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt index 8c4815c26eeabc..c07932a1a7ad0b 100644 --- a/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "igamma" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt index b43b54391b7d8f..8031a51db96e18 100644 --- a/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "igammac" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt index d75fcd63e3baeb..d75cef5fac3e15 100644 --- a/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "invert_permutation" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt index 27142644bf098b..91160bd8bfa776 100644 --- a/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "is_finite" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt index 4cd92f1cb78f22..7f029ee8cf0c7c 100644 --- a/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "is_inf" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt index 07d49f9436ea26..f2b8862c28d496 100644 --- a/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "is_nan" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt index 0262b838caa0e3..ee339967a48962 100644 --- a/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "lgamma" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt index 74145670a8f956..2dc5916f60f3f0 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matching_files" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt index 1122c52ab40423..c8aaf44b0d458c 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_band_part" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt index 9563bf0354598a..64a5950e56a287 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_determinant" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt index 8ab0bf75ebc5a4..57dc182474e242 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_diag" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt index 82ce67853c9507..142763f44bdc0a 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_diag_part" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt index 85862f6eb57096..13df986ac17620 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_inverse" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt index 6325e4f0e6e021..fc97a29cf21631 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_set_diag" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt index 6325dff407af71..0bbc9891590efa 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_solve" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt index 7f865e23b2ab90..17dc57335ae47e 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "matrix_triangular_solve" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt index 10b3aab0c771ec..6ea8094565b54c 100644 --- a/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "parse_tensor" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt index 9df81402d55242..33c96505ba4d23 100644 --- a/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "polygamma" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt index 0260eecc9172f2..e3a0e9d45a596c 100644 --- a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "qr" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt index 69404b947257d2..937a1a813d4845 100644 --- a/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "quantized_concat" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt index 9d479be45ff483..a671bc3ed14910 100644 --- a/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "read_file" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt index c4d4c27722266f..d10b87b6a7be8c 100644 --- a/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "reciprocal" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt index b3d596de7aaede..ee20249094cb7d 100644 --- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "manip.reshape" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt index 51478b7c3434d4..9ff0506c4e7f15 100644 --- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt @@ -5,10 +5,10 @@ op { } endpoint { name: "manip.reverse" - deprecated: true + deprecation_version: 2 } endpoint { name: "reverse_v2" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt index ec37a231273cf4..06e02f354c93d8 100644 --- a/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "rint" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt index 4fc2b8142108e0..3cfbfc1106e68f 100644 --- a/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "rsqrt" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt index 85888da45a2296..b76497d2661ea3 100644 --- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "manip.scatter_nd" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt index 2e22c375c071db..5f40b94b81eca9 100644 --- a/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "segment_max" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt index 646348072f08c2..a7da724f1dcae7 100644 --- a/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "segment_mean" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt index 1a77019a2dca9d..d4ccfe7457b74c 100644 --- a/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "segment_min" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt index cf4d6f0237dc9d..8bbd6ce105fd2d 100644 --- a/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "segment_prod" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt index c6d7999455039f..b40b5237a28350 100644 --- a/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "segment_sum" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt index 146b97f444a85a..9069a3e7a2f76a 100644 --- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "manip.space_to_batch_nd" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt index 4bab8cf00c34ba..c2ef8d3b34c486 100644 --- a/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "squared_difference" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt index 46a7c0361e21a8..a54cdb46c1f04a 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "string_join" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt index fbcdeaad6d3be2..fedc03a19da68c 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "string_strip" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt index d122e79b39466c..cf0b8831ef14bf 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "string_to_hash_bucket" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt index aef9dffefe5f49..06451a9ad57b42 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "string_to_hash_bucket_fast" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt index 385b9fd02ac214..8e103c8e2d3016 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "string_to_hash_bucket_strong" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt index f740b9849df4d2..155dd2675037b0 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "string_to_number" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt index 1d8695f1fdfdf7..3ffbe8cf526edb 100644 --- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "manip.tile" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt index cf8184324160bd..32044fd90edf9f 100644 --- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "unsorted_segment_max" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt index 475361c85a26f9..177e840e4272d9 100644 --- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "unsorted_segment_min" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt index a9d741bbc33a0b..f3aa8e8a515ed0 100644 --- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "unsorted_segment_prod" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt index 337678dcffe12d..1542bb039e0c0d 100644 --- a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "unsorted_segment_sum" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt index 1a58ae19e54195..d065027e9320d2 100644 --- a/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "write_file" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt index 4684a9d6242c5e..69bf4eb51d2698 100644 --- a/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt @@ -5,6 +5,6 @@ op { } endpoint { name: "zeta" - deprecated: true + deprecation_version: 2 } } diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 5b01f7fa037f4a..92e56df1810521 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -261,6 +261,13 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, }); } +void BaseCollectiveExecutor::CompleteParamsAsync( + const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, + StatusCallback done) { + cp->instance.gpu_ring_order = *gpu_ring_order_; + cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done); +} + Status BaseCollectiveExecutor::CreateCollective( const CollectiveParams& col_params, CollectiveImplementationInterface** col_impl) { diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index 360ce4db7bdab1..09826a8814511c 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -89,11 +89,13 @@ class BaseCollectiveExecutor : public CollectiveExecutor { public: BaseCollectiveExecutor(CollectiveExecutorMgrInterface* cem, PerStepCollectiveRemoteAccess* remote_access, - int64 step_id, const DeviceMgr* dev_mgr) + int64 step_id, const DeviceMgr* dev_mgr, + const string* gpu_ring_order) : CollectiveExecutor(cem), step_id_(step_id), dev_mgr_(dev_mgr), - remote_access_(remote_access) {} + remote_access_(remote_access), + gpu_ring_order_(gpu_ring_order) {} ~BaseCollectiveExecutor() override; @@ -102,6 +104,10 @@ class BaseCollectiveExecutor : public CollectiveExecutor { void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params, const string& exec_key, StatusCallback done) override; + void CompleteParamsAsync(const string& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, + StatusCallback done) override; + PerStepCollectiveRemoteAccess* remote_access() override { return remote_access_.get(); } @@ -133,6 +139,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor { const int64 step_id_; const DeviceMgr* dev_mgr_; // Not owned. std::unique_ptr remote_access_; + const string* gpu_ring_order_; // Not owned. private: Status CreateCollective(const CollectiveParams& col_params, diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc index 4f03a5e13ad59b..7bbc7ca06c5608 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc @@ -29,7 +29,9 @@ CollectiveExecutorMgr::CollectiveExecutorMgr( std::unique_ptr param_resolver) : dev_mgr_(dev_mgr), dev_resolver_(std::move(dev_resolver)), - param_resolver_(std::move(param_resolver)) {} + param_resolver_(std::move(param_resolver)), + gpu_ring_order_( + config.gpu_options().experimental().collective_ring_order()) {} CollectiveExecutorMgr::~CollectiveExecutorMgr() { for (auto iter : executor_table_) { @@ -56,7 +58,8 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) { CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) { CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id); - return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, + &gpu_ring_order_); } void CollectiveExecutorMgr::Cleanup(int64 step_id) { diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h index d53aca85b967c1..4db121a4d6d024 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/collective_executor_mgr.h @@ -62,8 +62,7 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { const DeviceMgr* dev_mgr_; std::unique_ptr dev_resolver_; std::unique_ptr param_resolver_; - CollectiveRemoteAccess* remote_access_; - string task_name_; + string gpu_ring_order_; private: mutex exec_mu_; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 1bc873d0c5c132..f90fb174344d4b 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/device_name_utils.h" @@ -170,8 +171,43 @@ GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip, return gdm; } -void OrderTaskDeviceMap(TaskDeviceMap* tdm) { +bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) { + std::vector gpu_ring_order_vec; + if (!str_util::SplitAndParseAsInts(gpu_ring_order_str, ',', + &gpu_ring_order_vec)) { + return false; + } + if (gpu_ring_order_vec.size() != tdm->size()) return false; + // gpu id -> local rank + gtl::FlatMap gpu_ranks; + for (int32 rank = 0; rank < static_cast(gpu_ring_order_vec.size()); + ++rank) { + gpu_ranks[gpu_ring_order_vec[rank]] = rank; + } + + for (auto& tdm_it : *tdm) { + DeviceNameUtils::ParsedName parsed_name; + DevRec* dr = &tdm_it.second; + if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) { + return false; + } + auto rank_it = gpu_ranks.find(parsed_name.id); + if (rank_it == gpu_ranks.end()) return false; + dr->local_rank = rank_it->second; + } + VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str; + return true; +} + +void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) { CHECK_GT(tdm->size(), 0); // Should never be called with 0 devices + + // If a valid ring order has been passed in via ConfigProto, use that. + if (ParseRingOrder(gpu_ring_order, tdm)) return; + + // Either no ring order was passed in, or the format was unexpected. + // We now assign a ring order based on link strengths. Note that this + // algorithm is not optimal and may not always find the best ring order. int least_rank = -1; string next_device; std::set selected; @@ -256,7 +292,7 @@ GlobalDeviceMap EstablishGlobalRank( GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities); for (auto& iter : gdm) { TaskDeviceMap& tdm = iter.second; - OrderTaskDeviceMap(&tdm); + OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm); } // Connect the global rank order by the order in which tasks first appear. std::set ordered_tasks; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index c5c3497e28cc9c..365bddc787a7ba 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -57,6 +57,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { const StatusCallback& done) override; protected: + // For access to InstanceRec and CompleteDefaultRanking. + friend class CollectiveParamResolverLocalTest; + // Used to complete/verify CollGroup. struct GroupRec { CollGroupParams group; diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index 9e1e2e8d5b24b3..2b43adbac69359 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -44,12 +44,116 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { task_name)); } + void RunCompleteDefaultRanking( + const CollectiveParams& shared_cp, + const std::vector& localities, + const std::vector& gpu_ring_order, + const std::vector& expected_device_order) { + CollectiveParams cp; + cp.instance.device_names = shared_cp.instance.device_names; + CollectiveParamResolverLocal::InstanceRec ir; + { + mutex_lock l(ir.out_mu); + ir.shared.name = shared_cp.name; + ir.shared.group = shared_cp.group; + ir.shared.instance = shared_cp.instance; + if (!gpu_ring_order.empty()) { + ir.shared.instance.gpu_ring_order = ""; + for (int i = 0; i < static_cast(gpu_ring_order.size() - 1); + ++i) { + ir.shared.instance.gpu_ring_order = strings::StrCat( + ir.shared.instance.gpu_ring_order, gpu_ring_order[i], ","); + } + ir.shared.instance.gpu_ring_order = strings::StrCat( + ir.shared.instance.gpu_ring_order, gpu_ring_order.back()); + } + VLOG(2) << "gpu_ring_order " << ir.shared.instance.gpu_ring_order; + prl_->CompleteDefaultRanking(nullptr, &cp, &ir, localities); + EXPECT_EQ(ir.shared.instance.device_names, expected_device_order); + } + } + std::vector devices_; std::unique_ptr device_mgr_; std::unique_ptr drl_; std::unique_ptr prl_; }; +TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { + constexpr int kNumGpus = 8; + CollectiveParams cp; + std::vector localities(kNumGpus); + cp.name = "PRLTest"; + cp.group.device_type = DeviceType("GPU"); + cp.group.num_tasks = 1; + cp.group.group_size = kNumGpus; + cp.instance.instance_key = 5; + cp.instance.type = REDUCTION_COLLECTIVE; + cp.instance.data_type = DataType(DT_FLOAT); + std::unordered_set clique1 = {0, 1, 6, 7}; + for (int gpu_idx = 0; gpu_idx < kNumGpus; ++gpu_idx) { + cp.instance.task_names.push_back("/job:localhost/replica:0/task:0"); + cp.instance.device_names.push_back(strings::StrCat( + "/job:localhost/replica:0/task:0/device:GPU:", gpu_idx)); + DeviceLocality* locality = &localities[gpu_idx]; + // Build localities so that 0,1,6,7 and 2,3,4,5 form 2 strongly connected + // components. Across components, connect 3 and 7. + for (int link_idx = 0; link_idx < kNumGpus; ++link_idx) { + if (gpu_idx == link_idx) continue; + bool gpu_in_clique1 = clique1.find(gpu_idx) != clique1.end(); + bool link_in_clique1 = clique1.find(link_idx) != clique1.end(); + if ((gpu_in_clique1 && link_in_clique1) || + (!gpu_in_clique1 && !link_in_clique1)) { + LocalLinks* links = locality->mutable_links(); + InterconnectLink* ilink = links->add_link(); + ilink->set_device_id(link_idx); + ilink->set_strength(2); + } else if ((gpu_idx == 3 && link_idx == 7) || + (gpu_idx == 7 && link_idx == 3)) { + LocalLinks* links = locality->mutable_links(); + InterconnectLink* ilink = links->add_link(); + ilink->set_device_id(link_idx); + ilink->set_strength(1); + } + } + } + RunCompleteDefaultRanking(cp, localities, {1, 3, 5, 7, 6, 4, 2, 0}, + { + "/job:localhost/replica:0/task:0/device:GPU:1", + "/job:localhost/replica:0/task:0/device:GPU:3", + "/job:localhost/replica:0/task:0/device:GPU:5", + "/job:localhost/replica:0/task:0/device:GPU:7", + "/job:localhost/replica:0/task:0/device:GPU:6", + "/job:localhost/replica:0/task:0/device:GPU:4", + "/job:localhost/replica:0/task:0/device:GPU:2", + "/job:localhost/replica:0/task:0/device:GPU:0", + }); + RunCompleteDefaultRanking(cp, localities, {7, 6, 5, 4, 3, 2, 1, 0}, + { + "/job:localhost/replica:0/task:0/device:GPU:7", + "/job:localhost/replica:0/task:0/device:GPU:6", + "/job:localhost/replica:0/task:0/device:GPU:5", + "/job:localhost/replica:0/task:0/device:GPU:4", + "/job:localhost/replica:0/task:0/device:GPU:3", + "/job:localhost/replica:0/task:0/device:GPU:2", + "/job:localhost/replica:0/task:0/device:GPU:1", + "/job:localhost/replica:0/task:0/device:GPU:0", + }); + // With no gpu_ring_order passed, automatic link detection should kick in. + // Starting at dev 0, the best order would be: 0,1,6,7,3,2,4,5 + RunCompleteDefaultRanking(cp, localities, {}, + { + "/job:localhost/replica:0/task:0/device:GPU:0", + "/job:localhost/replica:0/task:0/device:GPU:1", + "/job:localhost/replica:0/task:0/device:GPU:6", + "/job:localhost/replica:0/task:0/device:GPU:7", + "/job:localhost/replica:0/task:0/device:GPU:3", + "/job:localhost/replica:0/task:0/device:GPU:2", + "/job:localhost/replica:0/task:0/device:GPU:4", + "/job:localhost/replica:0/task:0/device:GPU:5", + }); +} + TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { CollectiveParams cps[NUM_DEVS]; Status statuses[NUM_DEVS]; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 7b74c67c858659..a7b618c18be5a7 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -148,6 +148,7 @@ tf_cuda_library( ], visibility = ["//tensorflow:internal"], deps = [ + ":attr_builder", "@farmhash_archive//:farmhash", ] + select({ "//tensorflow:android": [ @@ -219,7 +220,6 @@ tf_cuda_library( hdrs = ["attr_builder.h"], visibility = ["//tensorflow:internal"], deps = [ - ":kernel_and_device", "@farmhash_archive//:farmhash", # Only the TF_AttrType enum is required, so pull in just the C headers. # TODO(b/113535673): Break this dependency and avoid the C header completely. diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc index 5c8369de8765bc..29edc4e3b8f5e4 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def.pb.h" diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h index c114ea4ba0212d..af5b7d80c324d9 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.h +++ b/tensorflow/core/common_runtime/eager/attr_builder.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -157,7 +156,6 @@ template <> AttrBuilder& AttrBuilder::Set(StringPiece attr_name, tensorflow::DataType&& value); - } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h index 8a887540b06605..953b3580c2ea91 100644 --- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -30,7 +30,7 @@ class CopyToDeviceNode : public EagerNode { src_(src), dstd_(dstd), ctx_(ctx), - dst_(new TensorHandle(id, src_->dtype, ctx)) { + dst_(new TensorHandle(id, dstd_, dstd_, src->dtype, ctx)) { src_->Ref(); dst_->Ref(); } diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index c5f1d52e43d95b..0fcf5d93877b31 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -277,33 +277,20 @@ Status EagerLocalExecute(EagerOperation* op, LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory()); - auto* flr = ctx->func_lib(device); + auto* flr = ctx->func_lib(device); if (flr == nullptr) { return errors::Unavailable( "Unable to find a FunctionLibraryRuntime corresponding to device ", device->name()); } + kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory()); status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel); if (!status.ok()) { delete kernel; return status; } - // Update output_dtypes inside `kernel`. - const OpDef* op_def = nullptr; - const FunctionDef* function_def = ctx->FuncLibDef()->Find(ndef.op()); - if (function_def != nullptr) { - op_def = &(function_def->signature()); - } - if (op_def == nullptr) { - status = OpDefForOp(ndef.op().c_str(), &op_def); - if (!status.ok()) return status; - } - DataTypeVector input_dtypes; - status = InOutTypesForNode(ndef, *op_def, &input_dtypes, - kernel->mutable_output_dtypes()); - if (!status.ok()) return status; + ctx->AddKernelToCache(cache_key, kernel); } const DataTypeVector& output_dtypes = kernel->output_dtypes(); @@ -346,8 +333,17 @@ Status EagerLocalExecute(EagerOperation* op, // input handles are ready before executing them. // TODO(agarwal): Consider executing "cheap" kernels inline for performance. tensorflow::uint64 id = ctx->NextId(); + const MemoryTypeVector* output_memory_types = nullptr; + output_memory_types = &kernel->kernel()->output_memory_types(); + + Device* op_device = kernel->device(); for (int i = 0; i < *num_retvals; ++i) { - (*retvals)[i] = new TensorHandle(id, output_dtypes[i], ctx); + Device* d = op_device; + if (d != nullptr && output_memory_types != nullptr && + (*output_memory_types)[i] == HOST_MEMORY) { + d = nullptr; + } + (*retvals)[i] = new TensorHandle(id, d, op_device, output_dtypes[i], ctx); } EagerNode* node = new ExecuteNode( id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(), diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index b63257907f6129..ac9fd187b3427a 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/allocator.h" @@ -33,17 +34,27 @@ limitations under the License. namespace tensorflow { // static -Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, +Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flr, std::function)>* runner, KernelAndDevice* out) { OpKernel* k = nullptr; - Status s = flib->CreateKernel(ndef, &k); - out->device_ = flib->device(); + TF_RETURN_IF_ERROR(flr->CreateKernel(ndef, &k)); + out->device_ = flr->device(); out->kernel_.reset(k); - out->flib_ = flib; + out->flr_ = flr; out->runner_ = runner; out->default_runner_ = [](std::function f) { f(); }; - return s; + + // Update output_dtypes_. + const OpDef* op_def = nullptr; + const FunctionDef* function_def = + flr->GetFunctionLibraryDefinition()->Find(ndef.op()); + if (function_def != nullptr) { + op_def = &(function_def->signature()); + } else { + TF_RETURN_IF_ERROR(OpDefForOp(ndef.op().c_str(), &op_def)); + } + return OutputTypesForNode(ndef, *op_def, &out->output_dtypes_); } Status KernelAndDevice::Run(std::vector* inputs, @@ -80,7 +91,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, params.op_kernel = kernel_.get(); params.resource_manager = device_->resource_manager(); params.output_attr_array = gtl::vector_as_array(&out_attrs); - params.function_library = flib_; + params.function_library = flr_; params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; params.cancellation_manager = &cm_; @@ -120,7 +131,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, outputs->push_back(Tensor(*context.mutable_output(i))); } if (stats != nullptr) { - for (const auto& allocator_pair : context.wrapped_allocators()) { + for (const auto& allocator_pair : context.ConsumeWrappedAllocators()) { AllocatorMemoryUsed* memory = stats->add_memory(); memory->set_allocator_name(allocator_pair.first->Name()); auto sizes = allocator_pair.second->GetSizes(); diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index ac9143b253a2fc..4b0f5182a0e4d2 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -50,13 +50,13 @@ class KernelAndDevice { // // The provided FunctionLibraryRuntime MUST outlive all calls to // Run() on the returned KernelAndDevice. - static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, + static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flr, std::function)>* runner, KernelAndDevice* out); KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory) : device_(nullptr), - flib_(nullptr), + flr_(nullptr), rendez_(rendez), log_memory_(log_memory) {} @@ -73,7 +73,6 @@ class KernelAndDevice { Device* device() const { return device_; } - DataTypeVector* mutable_output_dtypes() { return &output_dtypes_; } const DataTypeVector& output_dtypes() { return output_dtypes_; } private: @@ -84,7 +83,7 @@ class KernelAndDevice { CancellationManager cm_; std::unique_ptr kernel_; Device* device_; - FunctionLibraryRuntime* flib_; + FunctionLibraryRuntime* flr_; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; Rendezvous* rendez_; DataTypeVector output_dtypes_; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index d58724cbfacf6f..655add00e9bb66 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -80,15 +80,11 @@ Status TensorHandle::Tensor(const tensorflow::Tensor** t) { } Status TensorHandle::Device(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); *d = device_; return Status::OK(); } Status TensorHandle::OpDevice(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); *d = op_device_; return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index e55f1a03385f2d..4f2c1a31a47796 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -61,12 +61,13 @@ class TensorHandle : public core::RefCounted { ctx_(ctx), is_ready_(true) {} - TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx) + TensorHandle(uint64 node_id, Device* d, Device* op_device, DataType dtype, + EagerContext* ctx) : dtype(dtype), node_id_(node_id), tensor_(dtype), - device_(nullptr), - op_device_(nullptr), + device_(d), + op_device_(op_device), remote_op_id_(-1), remote_output_num_(-1), remote_shape_node_id_(-1), diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index f5c6a5c669487d..1e68954827f3c5 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1651,9 +1651,10 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { params.track_allocations = false; stats = nullptr; if (stats_collector_ && !tagged_node.is_dead) { - // track allocations if and only if we are collecting statistics - params.track_allocations = true; stats = stats_collector_->CreateNodeExecStats(node); + // Track allocations if and only if we are collecting statistics, and + // `stats` object is expecting allocations to be tracked. + params.track_allocations = stats ? stats->TrackAllocations() : false; nodestats::SetScheduled(stats, scheduled_nsec); nodestats::SetAllStart(stats); } diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index e0e5f4a21560f3..6775695fa2d941 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -139,6 +139,153 @@ static Node* AddRet(Graph* g, Endpoint input, int index) { return ret; } +// FunctionLibraryRuntime implementation that forwards all the function calls to +// the base runtime implementation, and only overrides overlay lib in calls to +// Instantiate (if caller doesn't provide its own overlay lib). +// +// When function library runtime (FunctionLibraryRuntimeImpl specifically) +// instantiates function into a Graph object, it also creates an Executor for +// it. That executor has a pointer to the function library runtime instance, +// that is used to instantiate all nested function calls. +// +// If the original function was instantiated using overlay lib, we must preserve +// that overlay lib in the executor's function library runtime. +// +// IMPORTANT: This runtime is intended for use only in executors created for +// functions instantiated into a graph in FunctionLibraryRuntimeImpl. +class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime { + public: + FunctionLibraryRuntimeOverlay( + FunctionLibraryRuntime* base_flr, + const FunctionLibraryDefinition* overlay_lib_def) + : base_flr_(base_flr), overlay_lib_def_(overlay_lib_def) {} + ~FunctionLibraryRuntimeOverlay() override; + + Status Instantiate(const string& function_name, AttrSlice attrs, + const InstantiateOptions& options, + Handle* handle) override; + + Status ReleaseHandle(Handle handle) override; + + const FunctionBody* GetFunctionBody(Handle h) override; + + void Run(const Options& opts, Handle handle, gtl::ArraySlice args, + std::vector* rets, DoneCallback done) override; + + void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame, + DoneCallback done) override; + + Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; + + bool IsStateful(const string& function_name) override; + + const FunctionLibraryDefinition* GetFunctionLibraryDefinition() + const override; + + Env* env() override; + Device* device() override; + const DeviceMgr* device_mgr() const override; + + string DebugString(Handle handle) override; + int graph_def_version() override; + + Status Clone(std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) override; + + private: + FunctionLibraryRuntime* base_flr_; // not owned + const FunctionLibraryDefinition* overlay_lib_def_; // not owned +}; + +FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default; + +Status FunctionLibraryRuntimeOverlay::Instantiate( + const string& function_name, AttrSlice attrs, + const InstantiateOptions& options, Handle* handle) { + // We automatically add overlay lib to all instantiations, if the caller + // doesn't provide its own override. + if (!options.overlay_lib && overlay_lib_def_) { + InstantiateOptions options_copy = options; + options_copy.overlay_lib = overlay_lib_def_; + return base_flr_->Instantiate(function_name, attrs, options_copy, handle); + } else { + return base_flr_->Instantiate(function_name, attrs, options, handle); + } +} + +Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) { + return base_flr_->ReleaseHandle(handle); +} + +const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) { + return base_flr_->GetFunctionBody(h); +} + +void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle, + gtl::ArraySlice args, + std::vector* rets, + DoneCallback done) { + base_flr_->Run(opts, handle, args, rets, std::move(done)); +} + +void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle, + CallFrameInterface* call_frame, + DoneCallback done) { + base_flr_->Run(opts, handle, call_frame, std::move(done)); +} + +Status FunctionLibraryRuntimeOverlay::CreateKernel(const NodeDef&, OpKernel**) { + // We don't have access base_lib_def_ in base function library runtime (aka + // FunctionLibraryRuntimeImpl), so to make sure we do not create kernel with + // wrong lib_def we just disable creation of new kernels through overlays. + // + // When we call Instantiate from the base runtime with overlay lib override, + // the base runtime implementation is responsible for correctly passing custom + // overlay lib to all kernel constructions. + return errors::Internal( + "Overlay function library runtime doesn't support kernel creation."); +} + +bool FunctionLibraryRuntimeOverlay::IsStateful(const string& function_name) { + // Important: we do not forward lookup to the base FLR. + const OpDef* op_def; + const Status s = overlay_lib_def_->LookUpOpDef(function_name, &op_def); + return s.ok() && op_def->is_stateful(); +} + +Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); } + +Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); } + +const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const { + return base_flr_->device_mgr(); +} + +const FunctionLibraryDefinition* +FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const { + return overlay_lib_def_ ? overlay_lib_def_ + : base_flr_->GetFunctionLibraryDefinition(); +} + +string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) { + return base_flr_->DebugString(handle); +} + +int FunctionLibraryRuntimeOverlay::graph_def_version() { + return base_flr_->graph_def_version(); +} + +Status FunctionLibraryRuntimeOverlay::Clone( + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) { + // NOTE(ezhulenev): Cloned FunctionLibraryRuntime will be missing overlay lib, + // but that's ok because we anyway do not copy/clone instantiated items from + // the base FLR. + return base_flr_->Clone(out_lib_def, out_pflr, out_flr); +} + class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { public: FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, @@ -216,11 +363,13 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned. FunctionBody* func_graph = nullptr; Executor* exec = nullptr; + FunctionLibraryRuntimeOverlay* overlay_flr = nullptr; string executor_type; ~Item() { delete this->func_graph; delete this->exec; + delete this->overlay_flr; } }; std::unordered_map> items_ GUARDED_BY(mu_); @@ -233,8 +382,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, const FunctionLibraryDefinition* lib_def, FunctionBody** fbody); - Status CreateItem(Handle handle, Item** item); - Status GetOrCreateItem(Handle handle, Item** item); + Status CreateItem(Item** item); + Status GetOrCreateItem(LocalHandle local_handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, const FunctionLibraryDefinition* lib_def, FunctionBody** g_body); @@ -242,7 +391,11 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { AttrValueMap FixAttrs(const AttrSlice& attrs); void RunRemote(const Options& opts, Handle handle, gtl::ArraySlice args, std::vector* rets, - Executor::Args* exec_args, Item* item, DoneCallback done); + Item* item, DoneCallback done); + + void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts, + CallFrameInterface* frame, + Executor::Args* exec_args); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); }; @@ -538,13 +691,14 @@ Status FunctionLibraryRuntimeImpl::Instantiate( TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody)); } + LocalHandle local_handle; { mutex_lock l(mu_); *handle = parent_->GetHandle(key); if (*handle != kInvalidHandle) { delete fbody; - ++items_[parent_->GetHandleOnDevice(device_name_, *handle)] - ->instantiation_counter; + local_handle = parent_->GetHandleOnDevice(device_name_, *handle); + ++items_[local_handle]->instantiation_counter; } else { *handle = parent_->AddHandle(key, device_name_, next_handle_); Item* item = new Item; @@ -552,26 +706,28 @@ Status FunctionLibraryRuntimeImpl::Instantiate( item->overlay_lib = options.overlay_lib; item->instantiation_counter = 1; item->executor_type = ExecutorType(options, attrs); - items_.emplace(next_handle_, std::unique_ptr(item)); - next_handle_++; + if (options.overlay_lib) { + item->overlay_flr = + new FunctionLibraryRuntimeOverlay(this, options.overlay_lib); + } + local_handle = next_handle_++; + items_.emplace(local_handle, std::unique_ptr(item)); } } if (options.create_kernels_eagerly) { Item* item; - TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item)); + TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item)); } return Status::OK(); } Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { + LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); + if (h == kInvalidLocalHandle) { return parent_->ReleaseHandle(handle); } - - LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); - CHECK_NE(h, kInvalidLocalHandle); mutex_lock l(mu_); CHECK_EQ(1, items_.count(h)); std::unique_ptr& item = items_[h]; @@ -632,7 +788,7 @@ void PruneFunctionBody(Graph* g) { } } // namespace -Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { +Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { const FunctionBody* fbody; const FunctionLibraryDefinition* lib_def; string executor_type; @@ -653,11 +809,14 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()), device()->name(), g.get())); - // Creates an executor based on the g. This must be done without + // Creates an executor based on the g. This must be done without // holding mu_ because create_kernel_ calls back into the library. LocalExecutorParams params; params.device = device_; - params.function_library = this; + params.function_library = + (*item)->overlay_flr + ? static_cast((*item)->overlay_flr) + : static_cast(this); if (lib_def == base_lib_def_) { params.create_kernel = create_kernel_; } else { @@ -683,13 +842,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { return Status::OK(); } -Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { - LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); +Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle, + Item** item) { { tf_shared_lock l(mu_); auto iter = items_.find(local_handle); if (iter == items_.end()) { - return errors::NotFound("Function handle ", handle, + return errors::Internal("Local function handle ", local_handle, " is not valid. Likely an internal error."); } *item = iter->second.get(); @@ -699,22 +858,37 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { } // NOTE: We need to call CreateItem out of mu_ because creating an // executor needs to call CreateKernel. - return CreateItem(handle, item); + return CreateItem(item); +} + +void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions( + const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame, + Executor::Args* exec_args) { + // Inherit the step_id from the caller. + exec_args->step_id = run_opts.step_id; + exec_args->rendezvous = run_opts.rendezvous; + exec_args->stats_collector = run_opts.stats_collector; + exec_args->cancellation_manager = run_opts.cancellation_manager; + exec_args->step_container = run_opts.step_container; + if (run_opts.runner) { + exec_args->runner = *run_opts.runner; + } else { + exec_args->runner = default_runner_; + } + exec_args->collective_executor = run_opts.collective_executor; + exec_args->call_frame = frame; } void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, gtl::ArraySlice args, std::vector* rets, - Executor::Args* exec_args, Item* item, DoneCallback done) { - DCHECK(exec_args->call_frame == nullptr); string target_device = parent_->GetDeviceName(handle); string source_device = opts.source_device; Rendezvous* rendezvous = opts.rendezvous; DeviceContext* device_context; Status s = parent_->GetDeviceContext(target_device, &device_context); if (!s.ok()) { - delete exec_args; done(s); return; } @@ -722,7 +896,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, s = parent_->GetDeviceIncarnation(source_device, &src_incarnation); s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation)); if (!s.ok()) { - delete exec_args; done(s); return; } @@ -730,13 +903,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); - exec_args->call_frame = frame; - if (!s.ok()) { - delete frame; - delete exec_args; - done(s); - return; - } + Executor::Args* exec_args = new Executor::Args; + ExecutorArgsFromOptions(opts, frame, exec_args); std::vector args_alloc_attrs, rets_alloc_attrs; args_alloc_attrs.reserve(fbody->arg_types.size()); @@ -782,10 +950,10 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, return; } item->exec->RunAsync( - *exec_args, [frame, rets, done, source_device, target_device, - target_incarnation, rendezvous, device_context, - remote_args, exec_args, rets_alloc_attrs, - allow_dead_tensors](const Status& status) { + *exec_args, + [frame, rets, done, source_device, target_device, + target_incarnation, rendezvous, device_context, remote_args, + rets_alloc_attrs, allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { s = frame->ConsumeRetvals(rets, allow_dead_tensors); @@ -793,7 +961,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, delete frame; if (!s.ok()) { delete remote_args; - delete exec_args; done(s); return; } @@ -801,9 +968,9 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, target_device, source_device, "ret_", target_incarnation, *rets, device_context, rets_alloc_attrs, rendezvous); delete remote_args; - delete exec_args; done(s); }); + delete exec_args; }); } @@ -826,7 +993,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, }; } - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); + if (local_handle == kInvalidLocalHandle) { parent_->Run(run_opts, handle, args, rets, done); return; } @@ -836,54 +1004,43 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } DCHECK(run_opts.runner != nullptr); - Executor::Args* exec_args = new Executor::Args; - // Inherit the step_id from the caller. - exec_args->step_id = run_opts.step_id; - exec_args->rendezvous = run_opts.rendezvous; - exec_args->stats_collector = run_opts.stats_collector; - exec_args->cancellation_manager = run_opts.cancellation_manager; - exec_args->step_container = run_opts.step_container; - exec_args->runner = *run_opts.runner; - exec_args->collective_executor = run_opts.collective_executor; - Item* item = nullptr; - Status s = GetOrCreateItem(handle, &item); + Status s = GetOrCreateItem(local_handle, &item); if (!s.ok()) { - delete exec_args; done(s); return; } if (run_opts.remote_execution) { // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us. - RunRemote(run_opts, handle, args, rets, exec_args, item, done); + RunRemote(run_opts, handle, args, rets, item, done); return; } const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); - exec_args->call_frame = frame; s = frame->SetArgs(args); if (!s.ok()) { delete frame; - delete exec_args; done(s); return; } - bool allow_dead_tensors = opts.allow_dead_tensors; + Executor::Args exec_args; + ExecutorArgsFromOptions(run_opts, frame, &exec_args); + + bool allow_dead_tensors = run_opts.allow_dead_tensors; item->exec->RunAsync( // Executor args - *exec_args, + exec_args, // Done callback. - [frame, rets, done, exec_args, allow_dead_tensors](const Status& status) { + [frame, rets, done, allow_dead_tensors](const Status& status) { Status s = status; if (s.ok()) { s = frame->ConsumeRetvals(rets, allow_dead_tensors); } delete frame; - delete exec_args; done(s); }); } @@ -895,8 +1052,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, done(errors::Cancelled("")); return; } - if (!parent_->IsInstantiatedOnDevice(device_name_, handle) || - opts.remote_execution) { + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); + if (local_handle == kInvalidLocalHandle || opts.remote_execution) { done(errors::Unimplemented("Remote calling with CallFrameInterface")); return; } @@ -917,7 +1074,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } Item* item = nullptr; - Status s = GetOrCreateItem(handle, &item); + Status s = GetOrCreateItem(local_handle, &item); if (!s.ok()) { done(s); return; @@ -928,16 +1085,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, DCHECK(run_opts.runner != nullptr); Executor::Args exec_args; - // Inherit the step_id from the caller. - exec_args.step_id = run_opts.step_id; - exec_args.rendezvous = run_opts.rendezvous; - exec_args.stats_collector = run_opts.stats_collector; - exec_args.cancellation_manager = run_opts.cancellation_manager; - exec_args.collective_executor = run_opts.collective_executor; - exec_args.step_container = run_opts.step_container; - exec_args.runner = *run_opts.runner; - exec_args.call_frame = frame; - + ExecutorArgsFromOptions(run_opts, frame, &exec_args); item->exec->RunAsync(exec_args, std::move(done)); } @@ -949,7 +1097,8 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { Item* item = nullptr; - Status s = GetOrCreateItem(handle, &item); + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); + Status s = GetOrCreateItem(local_handle, &item); if (s.ok()) { return tensorflow::DebugString(item->graph); } else { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 13ed29a841b792..13c189fb87732c 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -436,6 +436,57 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesNInOverlayLib) { "Not found: Function XTimesTwo is not defined."); } +TEST_F(FunctionLibraryRuntimeTest, XTimesNInOverlayLibAndDelayedInstantiation) { + using FDH = ::tensorflow::FunctionDefHelper; + + Init({}); + + FunctionDef xt4_override = test::function::XTimesTwo(); + xt4_override.mutable_signature()->set_name("XTimesFour"); + + // Call XTimesFour via PartitionedCall which delays functions instantiation + // to the first call to Compute/ComputeAsync. + FunctionDef my_xt4 = FunctionDefHelper::Create( + "MyXTimesFour", {"x:float"}, {"z:float"}, {}, + {{{"x_times_four"}, + "PartitionedCall", + {"x"}, + {{"Tin", DataTypeSlice({DT_FLOAT})}, + {"Tout", DataTypeSlice({DT_FLOAT})}, + {"f", FDH::FunctionRef("XTimesFour", {{"T", DT_FLOAT}})}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "x_times_four:output:0"}}); + + FunctionDefLibrary lib; + *lib.add_function() = test::function::XTimesTwo(); + *lib.add_function() = test::function::XTimesFour(); + *lib.add_function() = my_xt4; + std::unique_ptr overlay_lib( + new FunctionLibraryDefinition(OpRegistry::Global(), lib)); + + FunctionLibraryRuntime::InstantiateOptions options; + options.overlay_lib = overlay_lib.get(); + + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + + // When we instantiate with default library overlay we should get x*4. + TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y})); + test::ExpectTensorEqual(y, test::AsTensor({4, 8, 12, 16})); + + // Overlay library that overrides default XTimesFour with XTimesTwo body. + FunctionDefLibrary lib_override; + *lib_override.add_function() = xt4_override; + *lib_override.add_function() = my_xt4; + std::unique_ptr overlay_lib_override( + new FunctionLibraryDefinition(OpRegistry::Global(), lib_override)); + + // We should call the XTimesFour override which is actually x*2. + options.overlay_lib = overlay_lib_override.get(); + TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y})); + test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); +} + TEST_F(FunctionLibraryRuntimeTest, StateHandle) { auto T = DT_INT32; @@ -1381,7 +1432,9 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { GraphDef actual; g->ToGraphDef(&actual); - TF_EXPECT_GRAPH_EQ(expected, actual); + // The optimizer is non-deterministic, so we only check that the number of + // nodes is not greater than expected. + EXPECT_LE(actual.node_size(), expected.node_size()); } } diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index afa219cc0bacd7..ab619ef619acab 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -76,7 +76,7 @@ GraphExecutionState::~GraphExecutionState() { GraphDef* graph_def, const GraphExecutionStateOptions& options, std::unique_ptr* out_state) { #ifndef __ANDROID__ - VLOG(4) << "Graph proto is " << graph_def->DebugString(); + VLOG(4) << "Graph proto is \n" << graph_def->DebugString(); #endif // __ANDROID__ std::unique_ptr ret( diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index da0e359cf8abdd..2144eea84f0a86 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -245,11 +245,12 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { if (!dev_mgr_ || device_type == DEVICE_CPU) { dev_mgr_.reset(new DeviceMgr(local_devices)); } + if (!gpu_ring_order_) gpu_ring_order_.reset(new string()); dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get())); rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId, fail_after); - col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, - dev_mgr_.get()); + col_exec_ = new BaseCollectiveExecutor( + &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get()); col_params_.name = "test_collective"; col_params_.instance.data_type = dtype; static const int kGroupKey = 6; @@ -715,6 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { CollectiveParams col_params_; std::vector gpu_devices_; std::unique_ptr dev_mgr_; + std::unique_ptr gpu_ring_order_; mutex mu_; int bcast_recv_counter_ GUARDED_BY(mu_) = 0; int bcast_send_counter_ GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 5e1ed130808a50..305d6a3b1bddca 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -575,11 +577,21 @@ class ColocationGraph { for (Device* d : device_set_->devices()) { registered_device_types.insert(d->device_type()); } + std::vector attr_key_vals; + for (const auto& it : node.attrs()) { + const string& name = it.first; + const AttrValue& attr_value = it.second; + attr_key_vals.push_back( + strings::StrCat(name, "=", SummarizeAttrValue(attr_value))); + } return errors::InvalidArgument( "No OpKernel was registered to support Op '", node.type_string(), - "' with these attrs. Registered devices: [", - str_util::Join(registered_device_types, ","), - "], Registered kernels:\n", + "' used by ", errors::FormatNodeNameForError(node.name()), + "with these attrs: [", str_util::Join(attr_key_vals, ", "), + "]\n" + "Registered devices: [", + str_util::Join(registered_device_types, ", "), "]\n", + "Registered kernels:\n", KernelsRegisteredForOp(node.type_string())); } diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index e3d2663b984163..d5e98b8d9e81b6 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -1028,9 +1028,10 @@ TEST_F(PlacerTest, TestNoKernelsRegistered) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), - "No OpKernel was registered to support Op 'VariableNoKernels'")); + EXPECT_TRUE( + str_util::StrContains(s.error_message(), + "No OpKernel was registered to support Op " + "'VariableNoKernels' used by {{node var}}")); EXPECT_TRUE( str_util::StrContains(s.error_message(), "")); } @@ -1052,9 +1053,9 @@ TEST_F(PlacerTest, TestNoDevicesRegistered) { Status s = Place(&g, &cpu_only); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), - "No OpKernel was registered to support Op 'VariableGPU'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "No OpKernel was registered to support Op " + "'VariableGPU' used by {{node var}}")); EXPECT_TRUE(str_util::StrContains(s.error_message(), "device='FakeGPU'")); } diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index 75aba435726bbd..a271bf7b747abb 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -187,11 +187,12 @@ class RingReducerTest : public ::testing::Test { << " devices: "; dev_mgr_.reset(new DeviceMgr(local_devices)); } + if (!gpu_ring_order_) gpu_ring_order_.reset(new string()); dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get())); rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId, fail_after); - col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId, - dev_mgr_.get()); + col_exec_ = new BaseCollectiveExecutor( + &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get()); col_params_.name = "test_collective"; static const int kGroupKey = 5; col_params_.group.group_key = kGroupKey; @@ -545,6 +546,7 @@ class RingReducerTest : public ::testing::Test { CollectiveParams col_params_; std::vector gpu_devices_; std::unique_ptr dev_mgr_; + std::unique_ptr gpu_ring_order_; mutex mu_; int32 reduce_counter_ GUARDED_BY(mu_) = 0; }; diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index a70ab93d4ad7f7..49265445659ff1 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -139,7 +139,7 @@ void NodeExecStatsWrapper::SetScheduled(int64 nanos) { } void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { - for (const auto& allocator_pair : ctx->wrapped_allocators()) { + for (const auto& allocator_pair : ctx->ConsumeWrappedAllocators()) { AddAllocation(allocator_pair.first, allocator_pair.second); } auto* ms = stats_->mutable_memory_stats(); diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 4365b11b19e1b1..7d34383ce8209c 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -68,8 +68,13 @@ class NodeExecStatsInterface { // Called immediately after this executor finishes processing this node. virtual void RecordExecutorEnded() = 0; + // Returns `true` if this object should track memory allocations. + virtual bool TrackAllocations() const = 0; + // Records information about the memory allocated during the execution of this // node. + // + // Takes ownership of any `TrackingAllocator` objects stored in `ctx`. virtual void SetMemory(OpKernelContext* ctx) = 0; // Records information about the tensor produced by this node at the given @@ -104,6 +109,7 @@ class NodeExecStatsWrapper : public NodeExecStatsInterface { void RecordComputeStarted() override; void RecordComputeEnded() override; void RecordExecutorEnded() override; + bool TrackAllocations() const override { return true; } void SetMemory(OpKernelContext* ctx) override; void SetOutput(int slot, const Tensor* tensor) override; void SetReferencedTensors(const TensorReferenceVector& tensors) override; diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index 805e023b0f3c86..9087703cb5524d 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -61,6 +61,15 @@ class RecvBufCall : public CancellableCall { RecvBufResponse resp_; }; +void PopulateTensorFromExtra(const RecvBufRespExtra& extra, + Tensor* cpu_tensor) { + char* head = reinterpret_cast(DMAHelper::base(cpu_tensor)); + for (const auto& tensor_content_chunk : extra.tensor_content()) { + memcpy(head, tensor_content_chunk.data(), + tensor_content_chunk.size()); + head += tensor_content_chunk.size(); + } +} } // namespace void CollectiveRemoteAccessDistributed::RecvFromPeer( @@ -95,7 +104,10 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( // them into the destination tensor here. RecvBufRespExtra extra; state->call->resp_.transport_options().UnpackTo(&extra); - int64 num_bytes = extra.tensor_content().size(); + int64 num_bytes = 0; + for (const auto& chunk : extra.tensor_content()) { + num_bytes += chunk.size(); + } if (num_bytes != to_tensor->TotalBytes()) { done(errors::Internal("RecvBufResponse returned ", num_bytes, " bytes where to_tensor expected ", @@ -118,8 +130,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( cpu_attr.set_gpu_compatible(true); Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr), to_tensor->dtype(), to_tensor->shape()); - memcpy(DMAHelper::base(cpu_tensor), extra.tensor_content().data(), - num_bytes); + PopulateTensorFromExtra(extra, cpu_tensor); // Then copy it to the GPU. CopyTensor::ViaDMA("", // edge name (non-existent) nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev, @@ -135,8 +146,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( return; } else { // CPU device - memcpy(DMAHelper::base(to_tensor), extra.tensor_content().data(), - num_bytes); + PopulateTensorFromExtra(extra, to_tensor); } } if (!s.ok() && errors::IsFailedPrecondition(s)) { diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc index bfd312410cb18f..33e1c8f2c33ff8 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc @@ -104,7 +104,7 @@ class FakeWorker : public TestWorkerInterface { // bytes in the response. RecvBufRespExtra extra; int64 num_bytes = h->prod_value->TotalBytes(); - extra.set_tensor_content(string( + extra.add_tensor_content(string( reinterpret_cast(DMAHelper::base(h->prod_value)), num_bytes)); response->mutable_transport_options()->PackFrom(extra); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index c4f2247145c20b..63d438c615567e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -194,8 +194,8 @@ Status GrpcServer::Init( MaybeMutateBuilder(&builder); master_impl_ = CreateMaster(&master_env_); master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder); - worker_impl_ = - worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_); + worker_impl_ = worker_func ? worker_func(&worker_env_) + : NewGrpcWorker(&worker_env_, config); worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder).release(); eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 1b6d796bd4331a..de80992095d13f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -418,8 +418,13 @@ class GrpcWorkerService : public AsyncServiceInterface { } // namespace -GrpcWorker::GrpcWorker(WorkerEnv* worker_env) - : Worker(worker_env), recent_request_ids_(100000) {} +GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config) + : Worker(worker_env), + recent_request_ids_(100000), + recv_buf_max_chunk_( + config.experimental().recv_buf_max_chunk() > 0 + ? config.experimental().recv_buf_max_chunk() + : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {} // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol // buffers for a response object, to avoid extra protocol buffer serialization @@ -505,6 +510,33 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +namespace { +// If RecvBufRespExtra.tensor_content is a single large string, then gRPC +// can stall on the recv side when the string buffer needs to be enlarged, +// since the size is not sent in advance. Changing this field to a sequence +// of small strings costs some extra time on the send side, since we do +// some otherwise unnecessary copies, but it improves runtime overall by +// improving flow control. Best performance is likely achieved with a +// max_chunk_bytes equal to the memory page size. +// +// TODO(tucker): When proto3 supports [ctype=CORD] then change +// RecvBufRespExtra.tensor_content to a cord instead of a repeated string, +// and remove this function. +void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor, + int64 num_bytes, RecvBufResponse* response) { + RecvBufRespExtra extra; + const char* head = reinterpret_cast(DMAHelper::base(tensor)); + while (num_bytes > 0) { + int64 bytes = + max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes; + extra.add_tensor_content(std::string(head, bytes)); + head += bytes; + num_bytes -= bytes; + } + response->mutable_transport_options()->PackFrom(extra); +} +} // namespace + void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, RecvBufResponse* response, StatusCallback done) { // This is a generic, low performance implementation appropriate for grpc. @@ -551,11 +583,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, [this, num_bytes, response, done, hook, cpu_tensor](const Status& s) { if (s.ok()) { - RecvBufRespExtra extra; - extra.set_tensor_content(reinterpret_cast( - DMAHelper::base(cpu_tensor)), - num_bytes); - response->mutable_transport_options()->PackFrom(extra); + SetTensorInRecvBufResp(recv_buf_max_chunk_, cpu_tensor, + num_bytes, response); } response->set_send_start_micros(env_->env->NowMicros()); done(s); @@ -566,11 +595,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, } } else { // Tensor is on CPU. - RecvBufRespExtra extra; - extra.set_tensor_content(reinterpret_cast( - DMAHelper::base(hook->prod_value)), - num_bytes); - response->mutable_transport_options()->PackFrom(extra); + SetTensorInRecvBufResp(recv_buf_max_chunk_, hook->prod_value, + num_bytes, response); } } response->set_send_start_micros(env_->env->NowMicros()); @@ -608,8 +634,9 @@ void GrpcWorker::LoggingAsync(const LoggingRequest* request, WorkerEnv* GrpcWorker::env() { return env_; } -std::unique_ptr NewGrpcWorker(WorkerEnv* env) { - return std::unique_ptr(new GrpcWorker(env)); +std::unique_ptr NewGrpcWorker(WorkerEnv* env, + const ConfigProto& config) { + return std::unique_ptr(new GrpcWorker(env, config)); } std::unique_ptr NewGrpcWorkerService( diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index d9e48524dea0f2..996617d385d1c0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -27,12 +27,13 @@ class ServerBuilder; namespace tensorflow { class AsyncServiceInterface; +class ConfigProto; struct WorkerEnv; struct WorkerSession; class GrpcWorker : public Worker { public: - GrpcWorker(WorkerEnv* env); + GrpcWorker(WorkerEnv* env, const ConfigProto& config); // Specialized version of RecvTensor for gRPC, which avoids a copy. virtual void GrpcRecvTensorAsync(CallOptions* opts, @@ -50,9 +51,11 @@ class GrpcWorker : public Worker { private: RecentRequestIds recent_request_ids_; + const int32 recv_buf_max_chunk_; }; -std::unique_ptr NewGrpcWorker(WorkerEnv* worker_env); +std::unique_ptr NewGrpcWorker(WorkerEnv* worker_env, + const ConfigProto& config); // Returns an implementation of WorkerService rpc service. std::unique_ptr NewGrpcWorkerService( diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc index 45b989f6e22676..054bed7781b8f4 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc @@ -49,7 +49,8 @@ CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) { CollectiveRemoteAccessDistributed* rma = new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), worker_cache_, step_id); - return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, + &gpu_ring_order_); } namespace { diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto index f8553cf5bbb690..b0f852170b159a 100644 --- a/tensorflow/core/framework/api_def.proto +++ b/tensorflow/core/framework/api_def.proto @@ -34,6 +34,10 @@ message ApiDef { // that should be logged when this op is used. // The message should indicate alternative op to use, if any. string deprecation_message = 12; + // Major version when the op will be deleted. For e.g. set this + // value to 2 if op API should be removed in TensorFlow 2.0 and + // deprecated in versions before that. + int32 deprecation_version = 13; enum Visibility { // Normally this is "VISIBLE" unless you are inheriting a @@ -64,6 +68,11 @@ message ApiDef { // to use a non-deprecated endpoint instead will be printed. If all // endpoints are deprecated, set deprecation_message in ApiDef instead. bool deprecated = 3; + + // Major version when an endpoint will be deleted. For e.g. set this + // value to 2 if endpoint should be removed in TensorFlow 2.0 and + // deprecated in versions before that. + int32 deprecation_version = 4; } repeated Endpoint endpoint = 3; diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index 4cb277d5a886a4..7fa58347f258ac 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -64,6 +64,7 @@ CollInstanceParams& CollInstanceParams::operator=( device_names.assign(other.device_names.begin(), other.device_names.end()); task_names.assign(other.task_names.begin(), other.task_names.end()); same_num_devices_per_task = other.same_num_devices_per_task; + gpu_ring_order = other.gpu_ring_order; impl_details.subdiv_offsets.assign( other.impl_details.subdiv_offsets.begin(), other.impl_details.subdiv_offsets.end()); diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index e35edb09d0c1ca..0321429702af74 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -85,6 +85,9 @@ struct CollInstanceParams { std::vector task_names; // True if every task has the same number of devices. bool same_num_devices_per_task = false; + // If passed in to GPUOptions in ConfigProto, defines a good ring order for + // GPUs. Assumes same GPU configuration at each worker. + string gpu_ring_order = ""; CollImplDetails impl_details; string ToString() const; CollInstanceParams& operator=(const struct CollInstanceParams& other); @@ -259,7 +262,9 @@ class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted { virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr, StatusCallback done) { - cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done); + done(errors::Internal( + "A collective Op has been called in a context in which " + "a CollectiveExecutor has not been provided.")); } virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; } diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 284dafb886e6dc..6852b97e744b5b 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -140,7 +140,7 @@ Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx, << " the graph. It will not be added again."; return Status::OK(); } - if (!ctx->allow_stateful_functions()) { + if (!ctx->optimization_only()) { TF_RETURN_IF_ERROR( EnsureFunctionIsStateless(ctx->flib_def(), function_name)); } @@ -203,25 +203,6 @@ bool GraphDefBuilderWrapper::HasAttr(const string& name, return HasAttr(op_def, attr_name); } -Status DatasetBase::Save(SerializationContext* ctx, - IteratorStateWriter* writer) const { - string serialized_graph_def; - string output_node; - GraphDefBuilder b; - DatasetGraphDefBuilder db(&b); - Node* node = nullptr; - TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); - output_node = node->name(); - GraphDef graph_def; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); - graph_def.SerializeToString(&serialized_graph_def); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); - return Status::OK(); -} - Status GetDatasetFromVariantTensor(const Tensor& tensor, DatasetBase** out_dataset) { if (!(tensor.dtype() == DT_VARIANT || @@ -251,6 +232,47 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { return Status::OK(); } +Status DatasetBase::Save(SerializationContext* ctx, + IteratorStateWriter* writer) const { + string serialized_graph_def; + string output_node; + GraphDefBuilder b; + DatasetGraphDefBuilder db(&b); + Node* node = nullptr; + TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); + output_node = node->name(); + GraphDef graph_def; + TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + graph_def.SerializeToString(&serialized_graph_def); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); + return Status::OK(); +} + +Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( + SerializationContext* ctx, const DatasetBase* dataset, Node** output) { + Status status = dataset->AsGraphDefInternal(ctx, this, output); + if (ctx->optimization_only() && errors::IsUnimplemented(status)) { + Tensor t(DT_VARIANT, TensorShape({})); + // `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We + // increment the refcount of `dataset` here to retain ownership. + dataset->Ref(); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(const_cast(dataset), &t)); + TF_RETURN_IF_ERROR(AddPlaceholder(t, output)); + DCHECK_NE(ctx->input_list(), nullptr); + ctx->input_list()->emplace_back((*output)->name(), std::move(t)); + LOG(WARNING) + << "Input of " << dataset->DebugString() + << " will not be optimized because the dataset does not implement the " + "AsGraphDefInternal() method needed to apply optimizations."; + return Status::OK(); + } + return status; +} + void DatasetOpKernel::Compute(OpKernelContext* ctx) { DatasetBase* dataset = nullptr; MakeDataset(ctx, &dataset); diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 964a7d5f8c20c9..ffd6b6202589ae 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -272,47 +272,70 @@ class StatsAggregator; class IteratorContext { public: struct Params { - // Interface to operating system functionality. - Env* env; + explicit Params(IteratorContext* ctx) + : allocator_getter(ctx->allocator_getter()), + env(ctx->env()), + function_library(ctx->function_library()), + lib(ctx->lib()), + model(ctx->model()), + runner(*(ctx->runner())), + runner_threadpool_size(ctx->runner_threadpool_size()), + stats_aggregator(ctx->stats_aggregator()) {} + + explicit Params(OpKernelContext* ctx) + : env(ctx->env()), + lib(ctx->function_library()), + runner(*(ctx->runner())), + runner_threadpool_size( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads) { + // NOTE: need reinterpret_cast because function.h forward-declares Device. + DeviceBase* device = + reinterpret_cast(ctx->function_library()->device()); + allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + } - // Function call support. - std::function)> runner = nullptr; + // The Allocator to be used to allocate the output of an iterator. + std::function allocator_getter = nullptr; - // The `StatsAggregator` object to record statistics about the iterator. - std::shared_ptr stats_aggregator = nullptr; + // Interface to operating system functionality. + Env* env = nullptr; - // The FunctionLibraryRuntime object to be used to make function calls. - FunctionLibraryRuntime* lib = nullptr; + // The FunctionLibraryDefinition used to look up user-defined functions. std::shared_ptr function_library = nullptr; - // The Allocator to be used to allocate the output of an iterator. - std::function allocator_getter = nullptr; + // The FunctionLibraryRuntime object to be used to make function calls. + FunctionLibraryRuntime* lib = nullptr; // If non-null, identifies the object used for performance modeling. std::shared_ptr model = nullptr; + + // Function call support. + std::function)> runner = nullptr; + + // Number of threads used for executing user-defined functions. + int32 runner_threadpool_size = 0; + + // The `StatsAggregator` object to record statistics about the iterator. + std::shared_ptr stats_aggregator = nullptr; }; + explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {} + + explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {} + explicit IteratorContext(Params params) : params_(std::move(params)) {} - explicit IteratorContext(OpKernelContext* ctx) { - params_.env = ctx->env(); - params_.runner = *(ctx->runner()); - params_.lib = ctx->function_library(); - // NOTE: must use reinterpret_cast because function.h forward-declares - // Device. - DeviceBase* device = - reinterpret_cast(ctx->function_library()->device()); - params_.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; + Allocator* allocator(AllocatorAttributes attrs) { + return params_.allocator_getter(attrs); } - Env* env() const { return params_.env; } - - std::function)>* runner() { - return ¶ms_.runner; + std::function allocator_getter() { + return params_.allocator_getter; } + Env* env() const { return params_.env; } std::shared_ptr function_library() { return params_.function_library; @@ -320,22 +343,18 @@ class IteratorContext { FunctionLibraryRuntime* lib() { return params_.lib; } - void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; } + std::shared_ptr model() { return params_.model; } - Allocator* allocator(AllocatorAttributes attrs) { - return params_.allocator_getter(attrs); + std::function)>* runner() { + return ¶ms_.runner; } - std::function allocator_getter() { - return params_.allocator_getter; - } + int32 runner_threadpool_size() { return params_.runner_threadpool_size; } std::shared_ptr stats_aggregator() { return params_.stats_aggregator; } - std::shared_ptr model() { return params_.model; } - Params params() { return params_; } private: @@ -346,21 +365,21 @@ class IteratorContext { class SerializationContext { public: struct Params { - bool allow_stateful_functions = false; const FunctionLibraryDefinition* flib_def = nullptr; // Not owned. std::vector>* input_list = nullptr; // Not owned. + bool optimization_only = false; }; explicit SerializationContext(Params params) : params_(std::move(params)) {} - bool allow_stateful_functions() { return params_.allow_stateful_functions; } - const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; } std::vector>* input_list() { return params_.input_list; } + bool optimization_only() { return params_.optimization_only; } + private: Params params_; @@ -429,6 +448,10 @@ class IteratorBase { } protected: + // Returns a node that models this iterator. + virtual std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const = 0; + // This is needed so that sub-classes of IteratorBase can call // `SaveInternal` on their input iterators. Status SaveInput(IteratorStateWriter* writer, @@ -511,22 +534,23 @@ class DatasetBase : public core::RefCounted { // // The prefix identifies the sequence of iterators leading up to the newly // created iterator. - Status MakeIterator(IteratorContext* ctx, const string& prefix, + Status MakeIterator(IteratorContext* ctx, const string& output_prefix, std::unique_ptr* iterator) const { - *iterator = MakeIteratorInternal(prefix); - if (ctx->model()) { - ctx->model()->AddNode((*iterator)->prefix(), prefix); - std::shared_ptr model = ctx->model(); + *iterator = MakeIteratorInternal(output_prefix); + std::shared_ptr model = ctx->model(); + if (model) { const string& prefix = (*iterator)->prefix(); + model->AddNode(MakeNodeFactory(ctx, iterator->get()), prefix, + output_prefix); (*iterator)->AddCleanupFunction( [model, prefix]() { model->RemoveNode(prefix); }); } return (*iterator)->Initialize(ctx); } - Status MakeIterator(IteratorContext&& ctx, const string& prefix, + Status MakeIterator(IteratorContext&& ctx, const string& output_prefix, std::unique_ptr* iterator) const { - return MakeIterator(&ctx, prefix, iterator); + return MakeIterator(&ctx, output_prefix, iterator); } // Returns a vector of DataType values, representing the respective @@ -553,9 +577,7 @@ class DatasetBase : public core::RefCounted { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} Status AddInputDataset(SerializationContext* ctx, - const DatasetBase* dataset, Node** output) { - return dataset->AsGraphDefInternal(ctx, this, output); - } + const DatasetBase* dataset, Node** output); }; // TODO(jsimsa): Consolidate overloading into a single method. @@ -567,6 +589,14 @@ class DatasetBase : public core::RefCounted { const string& prefix) const = 0; private: + // Returns a factory for nodes that represent the given iterator. + static model::Node::Factory MakeNodeFactory(IteratorContext* ctx, + IteratorBase* iterator) { + return [ctx, iterator](model::Node::Args args) { + return iterator->CreateNode(ctx, std::move(args)); + }; + } + const string name_; }; @@ -631,28 +661,11 @@ class DatasetBaseIterator : public IteratorBase { return strings::StrCat(params_.prefix, ":", name); } - // When performance modeling is enabled, this method adds a constant parameter - // to the model node corresponding to this iterator. - void AddConstantParameter(IteratorContext* ctx, const string& name, - int64 value) { - if (ctx->model()) { - ctx->model()->AddConstantParameter(prefix(), name, value); - } - } - - // When performance modeling is enabled, this method adds a tunable parameter - // to the model node corresponding to this iterator. - // - // The performance modeling logic may use `state` to set the value of the - // tunable parameter at any point during the lifetime of this iterator. When - // it does, it acquires `state->mu` and notifies `state->cond_var`. - void AddTunableParameter(IteratorContext* ctx, const string& name, - std::shared_ptr state, int64 min, - int64 max) { - if (ctx->model()) { - ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min, - max); - } + // By default we model iterators using an unknown node, which acts as + // pass-through with respect to performance modeling. + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownNode(std::move(args)); } // When performance modeling is enabled, this method records the fact that diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 5dfa19bef218ac..5650b4861b9306 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -21,271 +21,317 @@ namespace tensorflow { namespace data { namespace model { -// TODO(jsimsa): Use `Node` subclassing instead of types and node statements. -void Model::Node::CollectTunables( - std::vector>* tunables) { - tf_shared_lock l(mu_); - for (auto input : inputs_) { - input->CollectTunables(tunables); - } - switch (type_) { - case Type::MAP_AND_BATCH: - case Type::PARALLEL_INTERLEAVE_V2: - case Type::PARALLEL_MAP: { - if (auto* tunable_param = - gtl::FindOrNull(tunable_params_, "parallelism")) { - tunables->push_back(*tunable_param); - } - return; - } - default: - return; - } +std::shared_ptr MakeParameter(const string& name, + std::shared_ptr state, + int64 min, int64 max) { + return std::make_shared(name, state, min, max); } -int64 Model::Node::GetParameterValue(const string& name) { - if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) { - return (*tunable_param)->value; +namespace { + +// The first input of InterleaveMany corresponds to the input dataset whose +// elements are used to create the (derived) input datasets whose elements are +// interleaved as output. +// +// TODO(jsimsa): model the first input +class InterleaveMany : public Node { + public: + using Node::Node; + + virtual ~InterleaveMany() {} + + protected: + std::shared_ptr Clone(std::shared_ptr output) const override + SHARED_LOCKS_REQUIRED(mu_) { + return std::make_shared( + Args{id_, name_, std::move(output)}); } - return constant_params_[name]; -} -int64 Model::Node::ProcessingTimeLocked() { - switch (type_) { - case Type::BATCH: - case Type::MAP_AND_BATCH: - case Type::PADDED_BATCH: { - int64 batch_size = GetParameterValue("batch_size"); - return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs(); + int64 OutputTimeLocked(std::vector* input_times) const override + SHARED_LOCKS_REQUIRED(mu_) { + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); } - case Type::FILTER: { - if (inputs_.size() <= 1) { - return NanosPerElementLocked(); - } - std::shared_ptr input = inputs_.front(); - double ratio = 0.0L; - if (num_elements_ > 0) { - ratio = static_cast(input->num_elements()) / - static_cast(num_elements_); - } - return NanosPerElementLocked() + - static_cast(ratio * - static_cast(ProcessingTimeForInputs())); + int64 delta = NanosPerElementLocked() * (inputs_.size() - 1); + input_times->back() += delta; + auto cleanup = gtl::MakeCleanup( + [input_times, delta]() { input_times->back() -= delta; }); + int64 output_time = + static_cast(OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times)) / + static_cast(inputs_.size() - 1); + return NanosPerElementLocked() + output_time; + } + + int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) { + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); } - case Type::FLAT_MAP: - case Type::INTERLEAVE: - case Type::PARALLEL_INTERLEAVE: - case Type::PARALLEL_INTERLEAVE_V2: { - // TODO(jsimsa): model the first input - // TODO(jsimsa): use processing time history as a prior for future inputs - if (inputs_.size() <= 1) { - return NanosPerElementLocked(); - } - int64 processing_time = - ProcessingTimeForInputs() - inputs_.front()->ProcessingTime(); - return NanosPerElementLocked() + - static_cast(processing_time) / - static_cast(inputs_.size() - 1); + int64 processing_time = + static_cast(ProcessingTimeForInputs() - + inputs_.front()->ProcessingTime()) / + static_cast(inputs_.size() - 1); + return NanosPerElementLocked() + processing_time; + } +}; + +// TODO(jsimsa): model the first input +class AsyncInterleaveMany : public Node { + public: + AsyncInterleaveMany(Node::Args args, + std::vector> parameters) + : Node(args) { + for (auto& parameter : parameters) { + parameters_[parameter->name] = std::move(parameter); } - case Type::CACHE: - case Type::CONCATENATE: - case Type::MAP: - case Type::PARALLEL_MAP: - case Type::PREFETCH: - // TODO(jsimsa): use processing time history as a prior for future inputs - case Type::REPEAT: - case Type::SHUFFLE: - case Type::SKIP: - case Type::TAKE: - case Type::ZIP: { - return NanosPerElementLocked() + ProcessingTimeForInputs(); + } + + virtual ~AsyncInterleaveMany() {} + + protected: + std::shared_ptr Clone(std::shared_ptr output) const override + SHARED_LOCKS_REQUIRED(mu_) { + std::vector> parameters; + for (auto& pair : parameters_) { + parameters.push_back(pair.second); } - default: - return NanosPerElementLocked(); + return std::make_shared( + Args{id_, name_, std::move(output)}, parameters); } -} -int64 Model::Node::OutputTimeLocked(std::vector* input_times) { - switch (type_) { - case Type::BATCH: - case Type::PADDED_BATCH: { - double batch_size = GetParameterValue("batch_size"); - int64 old_value = (*input_times)[input_times->size() - 1]; - (*input_times)[input_times->size() - 1] = static_cast( - static_cast(old_value + NanosPerElementLocked()) / - batch_size); - auto cleanup = gtl::MakeCleanup([input_times, old_value]() { - (*input_times)[input_times->size() - 1] = old_value; - }); - return NanosPerElementLocked() + - batch_size * OutputTimeForInputs(input_times); + int64 OutputTimeLocked(std::vector* input_times) const override + SHARED_LOCKS_REQUIRED(mu_) { + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); } - case Type::FILTER: { - if (inputs_.size() <= 1) { - return NanosPerElementLocked(); - } - std::shared_ptr input = inputs_.front(); - double ratio = 0.0L; - if (num_elements_ > 0) { - ratio = static_cast(input->num_elements()) / - static_cast(num_elements_); - int64 old_value = (*input_times)[input_times->size() - 1]; - (*input_times)[input_times->size() - 1] = static_cast( - static_cast(old_value + NanosPerElementLocked()) / ratio); - auto cleanup = gtl::MakeCleanup([input_times, old_value]() { - (*input_times)[input_times->size() - 1] = old_value; - }); - } - return NanosPerElementLocked() + - static_cast( - static_cast(OutputTimeForInputs(input_times)) * ratio); + int64 old_input_time = input_times->back(); + int64 new_input_time = static_cast(NanosPerElementLocked()) * + static_cast(inputs_.size() - 1); + input_times->push_back(new_input_time); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + double parallelism = inputs_.size() - 1; // default to cycle length + if (auto* parameter = gtl::FindOrNull(parameters_, "parallelism")) { + parallelism = std::min(static_cast(parallelism), + static_cast((*parameter)->value)); } - case Type::FLAT_MAP: - case Type::INTERLEAVE: { - // TODO(jsimsa): model the first input - // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1` - if (inputs_.size() <= 1) { - return NanosPerElementLocked(); - } - int64 delta = - static_cast(static_cast(NanosPerElementLocked()) * - static_cast(inputs_.size() - 1)); - (*input_times)[input_times->size() - 1] += delta; - auto cleanup = gtl::MakeCleanup([input_times, delta]() { - (*input_times)[input_times->size() - 1] -= delta; - }); - int64 output_time = OutputTimeForInputs(input_times) - - inputs_.front()->OutputTime(input_times); - return NanosPerElementLocked() + - static_cast(output_time) / - static_cast(inputs_.size() - 1); + int64 output_time = + static_cast(OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times)) / + static_cast(inputs_.size() - 1) / parallelism; + return std::max(0LL, + NanosPerElementLocked() + output_time - old_input_time); + } + + int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) { + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); } - case Type::MAP_AND_BATCH: { - double batch_size = GetParameterValue("batch_size"); - double parallelism = GetParameterValue("parallelism"); - int64 delta = - static_cast(static_cast(NanosPerElementLocked()) / - (batch_size * parallelism)); - input_times->push_back(delta); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); - int64 output_time = static_cast( - static_cast(NanosPerElementLocked()) / parallelism + - batch_size * OutputTimeForInputs(input_times)); - return std::max(0LL, - output_time - input_times->at(input_times->size() - 2)); + int64 processing_time = + ProcessingTimeForInputs() - inputs_.front()->ProcessingTime(); + return NanosPerElementLocked() + + static_cast(processing_time) / + static_cast(inputs_.size() - 1); + } +}; + +class KnownRatio : public Node { + public: + KnownRatio(Node::Args args, int64 ratio) : Node(args), ratio_(ratio) {} + + virtual ~KnownRatio() {} + + protected: + std::shared_ptr Clone(std::shared_ptr output) const override + SHARED_LOCKS_REQUIRED(mu_) { + return std::make_shared(Args{id_, name_, std::move(output)}, + ratio_); + } + + int64 OutputTimeLocked(std::vector* input_times) const override + SHARED_LOCKS_REQUIRED(mu_) { + if (ratio_ == 0) { + return NanosPerElementLocked(); } - case Type::PARALLEL_INTERLEAVE: { - // TODO(jsimsa): model the first input - if (inputs_.size() <= 1) { - return NanosPerElementLocked(); - } - int64 delta = static_cast(NanosPerElementLocked()) * - static_cast(inputs_.size() - 1); - input_times->push_back(delta); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); - int64 inputs_output_time = OutputTimeForInputs(input_times) - - inputs_.front()->OutputTime(input_times); - double parallelism = GetParameterValue("parallelism"); - int64 output_time = - NanosPerElementLocked() + ((static_cast(inputs_output_time) / - static_cast(inputs_.size() - 1)) / - parallelism); - return std::max(0LL, - output_time - input_times->at(input_times->size() - 2)); + int64 old_input_time = input_times->back(); + input_times->back() += static_cast( + static_cast(old_input_time + NanosPerElementLocked()) / ratio_); + auto cleanup = gtl::MakeCleanup([input_times, old_input_time]() { + input_times->back() = old_input_time; + }); + return NanosPerElementLocked() + ratio_ * OutputTimeForInputs(input_times); + } + + int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) { + return NanosPerElementLocked() + ratio_ * ProcessingTimeForInputs(); + } + + private: + const double ratio_; +}; + +class AsyncKnownRatio : public Node { + public: + AsyncKnownRatio(Node::Args args, double ratio, + std::vector> parameters) + : Node(args), ratio_(ratio) { + for (auto& parameter : parameters) { + parameters_[parameter->name] = std::move(parameter); } - case Type::PARALLEL_INTERLEAVE_V2: { - // TODO(jsimsa): model the first input - if (inputs_.size() <= 1) { - return NanosPerElementLocked(); - } - int64 delta = static_cast(NanosPerElementLocked()) * - static_cast(inputs_.size() - 1); - input_times->push_back(delta); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); - int64 inputs_output_time = OutputTimeForInputs(input_times) - - inputs_.front()->OutputTime(input_times); - double parallelism = - std::min(static_cast(GetParameterValue("cycle_length")), - static_cast(GetParameterValue("parallelism"))); - int64 output_time = - NanosPerElementLocked() + ((static_cast(inputs_output_time) / - static_cast(inputs_.size() - 1)) / - parallelism); - return std::max(0LL, - output_time - input_times->at(input_times->size() - 2)); + } + + virtual ~AsyncKnownRatio() {} + + protected: + std::shared_ptr Clone(std::shared_ptr output) const override + SHARED_LOCKS_REQUIRED(mu_) { + std::vector> parameters; + for (auto& pair : parameters_) { + parameters.push_back(pair.second); } - case Type::PARALLEL_MAP: { - double parallelism = - std::min(port::NumSchedulableCPUs(), - static_cast(GetParameterValue("parallelism"))); - int64 delta = static_cast( - static_cast(NanosPerElementLocked()) / parallelism); - input_times->push_back(delta); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); - int64 output_time = - static_cast(NanosPerElementLocked()) / parallelism + - OutputTimeForInputs(input_times); - return std::max(0LL, - output_time - input_times->at(input_times->size() - 2)); + return std::make_shared( + Args{id_, name_, std::move(output)}, ratio_, parameters); + } + + int64 OutputTimeLocked(std::vector* input_times) const override + SHARED_LOCKS_REQUIRED(mu_) { + double parallelism = 1.0; + if (auto* parameter = gtl::FindOrNull(parameters_, "parallelism")) { + parallelism = (*parameter)->value; } - case Type::PREFETCH: { - int64 delta = NanosPerElementLocked(); - input_times->push_back(delta); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); - return std::max(0LL, NanosPerElementLocked() + - OutputTimeForInputs(input_times) - - input_times->at(input_times->size() - 2)); + if (ratio_ == 0.0) { + int64 output_time = + static_cast(NanosPerElementLocked()) / parallelism; + return std::max(0LL, output_time - input_times->back()); } - case Type::CACHE: - case Type::CONCATENATE: - case Type::MAP: - case Type::REPEAT: - case Type::SHUFFLE: - case Type::SKIP: - case Type::TAKE: - case Type::ZIP: { - int64 delta = NanosPerElementLocked(); - (*input_times)[input_times->size() - 1] += delta; - auto cleanup = gtl::MakeCleanup([input_times, delta]() { - (*input_times)[input_times->size() - 1] -= delta; - }); - return NanosPerElementLocked() + OutputTimeForInputs(input_times); + int64 old_input_time = input_times->back(); + int64 new_input_time = static_cast( + static_cast(NanosPerElementLocked()) / ratio_ / parallelism); + input_times->push_back(new_input_time); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = static_cast( + static_cast(NanosPerElementLocked()) / parallelism + + ratio_ * OutputTimeForInputs(input_times)); + return std::max(0LL, output_time - old_input_time); + } + + int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) { + return NanosPerElementLocked() + ratio_ * ProcessingTimeForInputs(); + } + + private: + const double ratio_; +}; + +class UnknownRatio : public Node { + public: + using Node::Node; + + virtual ~UnknownRatio() {} + + protected: + std::shared_ptr Clone(std::shared_ptr output) const override + SHARED_LOCKS_REQUIRED(mu_) { + return std::make_shared(Args{id_, name_, std::move(output)}); + } + + int64 OutputTimeLocked(std::vector* input_times) const override + SHARED_LOCKS_REQUIRED(mu_) { + if (num_elements_ == 0 || inputs_.empty() || + inputs_.front()->num_elements() == 0) { + return NanosPerElementLocked(); } - default: + // TODO(jsimsa): The current implementation assumes that the number of input + // elements consumed per output is the same across all inputs. + std::shared_ptr input = inputs_.front(); + double ratio = static_cast(input->num_elements()) / + static_cast(num_elements_); + int64 old_input_time = input_times->back(); + input_times->back() = + static_cast(old_input_time + NanosPerElementLocked()) / ratio; + auto cleanup = gtl::MakeCleanup([input_times, old_input_time]() { + input_times->back() = old_input_time; + }); + return NanosPerElementLocked() + + static_cast( + ratio * static_cast(OutputTimeForInputs(input_times))); + } + + int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) { + if (inputs_.empty() || num_elements_ == 0) { return NanosPerElementLocked(); + } + // TODO(jsimsa): The current implementation that the number of input + // elements consumed per output is the same across all inputs. + std::shared_ptr input = inputs_.front(); + double ratio = static_cast(input->num_elements()) / + static_cast(num_elements_); + return NanosPerElementLocked() + + static_cast(ratio * + static_cast(ProcessingTimeForInputs())); + } +}; + +class Unknown : public Node { + public: + using Node::Node; + + virtual ~Unknown() {} + + protected: + std::shared_ptr Clone(std::shared_ptr output) const override + SHARED_LOCKS_REQUIRED(mu_) { + return std::make_shared(Args{id_, name_, std::move(output)}); + } + + int64 OutputTimeLocked(std::vector* input_times) const override + SHARED_LOCKS_REQUIRED(mu_) { + return OutputTimeForInputs(input_times); + } + + int64 ProcessingTimeLocked() const override SHARED_LOCKS_REQUIRED(mu_) { + return ProcessingTimeForInputs(); } +}; + +} // namespace + +std::shared_ptr MakeInterleaveManyNode(Node::Args args) { + return std::make_shared(std::move(args)); } -std::shared_ptr Model::Node::Snapshot( - std::shared_ptr output) { - tf_shared_lock l(mu_); - std::shared_ptr result = - std::make_shared(id_, name_, std::move(output)); - result->processing_time_ = processing_time_; - result->num_elements_ = num_elements_; - result->constant_params_ = constant_params_; - result->tunable_params_ = tunable_params_; - for (auto& input : inputs_) { - result->add_input(input->Snapshot(result)); - } - return result; +std::shared_ptr MakeAsyncInterleaveManyNode( + Node::Args args, std::vector> parameters) { + return std::make_shared(std::move(args), + std::move(parameters)); } -void Model::AddConstantParameter(const string& node_name, - const string& parameter_name, int64 value) { - tf_shared_lock l(mu_); - auto node = gtl::FindOrNull(lookup_table_, node_name); - if (node) { - (*node)->add_constant_param(parameter_name, value); - } +std::shared_ptr MakeKnownRatioNode(Node::Args args, double ratio) { + return std::make_shared(std::move(args), ratio); } -void Model::AddNode(const string& name, const string& output_name) { +std::shared_ptr MakeAsyncKnownRatioNode( + Node::Args args, double ratio, + std::vector> parameters) { + return std::make_shared(std::move(args), ratio, + std::move(parameters)); +} + +std::shared_ptr MakeSourceNode(Node::Args args) { + return MakeKnownRatioNode(std::move(args), 0); +} + +std::shared_ptr MakeUnknownRatioNode(Node::Args args) { + return std::make_shared(std::move(args)); +} + +std::shared_ptr MakeUnknownNode(Node::Args args) { + return std::make_shared(std::move(args)); +} + +void Model::AddNode(Node::Factory factory, const string& name, + const string& output_name) { // The name captures the sequence of iterators joined by `::`. We use the full // sequence as the key in the lookup table, but only the last element of the // sequence as the name node. @@ -303,7 +349,7 @@ void Model::AddNode(const string& name, const string& output_name) { if (it != lookup_table_.end()) { output = it->second; } - std::shared_ptr node(new Node(id_counter_++, tokens.back(), output)); + std::shared_ptr node = factory({id_counter_++, tokens.back(), output}); if (!output_) { output_ = node; } @@ -321,16 +367,6 @@ void Model::AddProcessingTime(const string& name, int64 delta) { } } -void Model::AddTunableParameter(const string& node_name, - const string& parameter_name, - std::shared_ptr state, int64 min, - int64 max) { - tf_shared_lock l(mu_); - auto node = *gtl::FindOrNull(lookup_table_, node_name); - DCHECK(node); - node->add_tunable_param(parameter_name, std::move(state), min, max); -} - // The optimization algorithm starts by setting all tunable parallelism // parameters to 1. It then repeatedly identifies the parameter whose increase // in parallelism decreases the output time the most. This process is repeated @@ -338,43 +374,43 @@ void Model::AddTunableParameter(const string& node_name, // is less than or equal to the processing time needed to produce an element // divided by CPU budget. void Model::Optimize(int64 cpu_budget) { - std::shared_ptr snapshot; + std::shared_ptr snapshot; { tf_shared_lock lock(mu_); snapshot = output_->Snapshot(nullptr); } const int64 processing_time = ProcessingTime(snapshot); - auto tunables = CollectTunables(snapshot); - for (auto tunable : tunables) { - tunable->value = 1; + auto parameters = CollectTunableParameters(snapshot); + for (auto& parameter : parameters) { + parameter->value = 1; } while (true) { const int64 output_time = OutputTime(snapshot); - bool all_tunables = true; - for (auto& tunable : tunables) { - if (tunable->value < tunable->max) { - all_tunables = false; + bool all_max = true; + for (auto& parameter : parameters) { + if (parameter->value < parameter->max) { + all_max = false; break; } } - if (output_time < processing_time / cpu_budget || all_tunables) { + if (output_time < processing_time / cpu_budget || all_max) { break; } int64 best_delta = -1; - Model::Node::Tunable* best_tunable = nullptr; - for (auto& tunable : tunables) { - if (tunable->value == tunable->max) { + Parameter* best_parameter = nullptr; + for (auto& parameter : parameters) { + if (parameter->value == parameter->max) { continue; } - tunable->value++; + parameter->value++; int64 delta = output_time - OutputTime(snapshot); if (delta > best_delta) { best_delta = delta; - best_tunable = tunable.get(); + best_parameter = parameter.get(); } - tunable->value--; + parameter->value--; } - if (!best_tunable) { + if (!best_parameter) { // This should never happen because we are using a model snapshot and // the output time is monotonically decreasing w.r.t. parallelism. LOG(WARNING) << "Failed to find a tunable parameter that would " @@ -382,14 +418,14 @@ void Model::Optimize(int64 cpu_budget) { "optimization attempt."; return; } - best_tunable->value++; + best_parameter->value++; } - VLOG(2) << "Number of knobs: " << tunables.size(); - for (auto& tunable : tunables) { - VLOG(2) << "Setting tunable parameter: " << tunable->value; - mutex_lock l(*tunable->state->mu); - tunable->state->value = tunable->value; - tunable->state->cond_var->notify_all(); + VLOG(2) << "Number of tunable parameters: " << parameters.size(); + for (auto& parameter : parameters) { + VLOG(2) << "Setting tunable parameter: " << parameter->value; + mutex_lock l(*parameter->state->mu); + parameter->state->value = parameter->value; + parameter->state->cond_var->notify_all(); } } @@ -432,19 +468,19 @@ void Model::RemoveNode(const string& name) { lookup_table_.erase(name); } -std::vector> Model::CollectTunables( - std::shared_ptr node) { - std::vector> tunables; - node->CollectTunables(&tunables); - return tunables; +std::vector> Model::CollectTunableParameters( + std::shared_ptr node) { + std::vector> parameters; + node->CollectTunableParameters(¶meters); + return parameters; } -int64 Model::OutputTime(std::shared_ptr node) { +int64 Model::OutputTime(std::shared_ptr node) { std::vector input_times(1, 0); return node->OutputTime(&input_times); } -int64 Model::ProcessingTime(std::shared_ptr node) { +int64 Model::ProcessingTime(std::shared_ptr node) { return node->ProcessingTime(); } diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 8c376492e1ac34..635a760b22aee4 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -37,15 +37,288 @@ namespace model { // the performance model. struct SharedState { public: - explicit SharedState(int64 value, std::shared_ptr mu, - std::shared_ptr cond_var) - : mu(std::move(mu)), cond_var(std::move(cond_var)), value(value) {} + SharedState(int64 value, std::shared_ptr mu, + std::shared_ptr cond_var) + : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {} + int64 value; std::shared_ptr mu; std::shared_ptr cond_var; + bool tunable = false; +}; + +// Represents a parameter. +struct Parameter { + Parameter(const string& name, std::shared_ptr state, int64 min, + int64 max) + : name(name), + value(state->value), + min(min), + max(max), + state(std::move(state)) {} + + // Human-readable name of the parameter. + string name; + + // Identifies the model value of the parameter. This can be different from + // the actual value (e.g. during optimization search). int64 value; + + // Identifies the minimum value of the parameter. + int64 min; + + // Identifies the maximum value of the parameter. + int64 max; + + // Shared state of the parameter. + std::shared_ptr state; +}; + +std::shared_ptr MakeParameter(const string& name, + std::shared_ptr state, + int64 min, int64 max); + +// Abstract representation of a TensorFlow input pipeline node. It collects +// information about inputs to this node, processing time spent executing the +// node logic, number of elements produced by the node, various other +// information (e.g. batch size or execution parallelism). +// +// Developers of tf.data transformations are not expected to interact with +// this class directly. Boiler plate code for creating the abstract +// representation of the input pipeline and collecting common information has +// been added to the implementation of `DatasetBase` and `DatasetBaseIterator` +// respectively. +// +// In addition, `DatasetBaseIterator` provides wrappers that can be used for +// transformation-specific information collection. The `SetMetadata` wrapper +// can be used to pass arbitrary metadata to the modeling framework, while the +// `StartWork` and `StopWork` wrappers should be used to correctly account for +// processing time of multi-threaded transformation that yield the CPU; such +// transformations should invoke `StartWork()` when a transformation thread +// starts executing (e.g. when created or woken up) and `StopWork()` when a +// transformation thread stops executing (e.g. when returning or waiting). +class Node { + public: + // Arguments for `Node` constructor. + struct Args { + int64 id; + string name; + std::shared_ptr output; + }; + + using Factory = std::function(Args)>; + + explicit Node(Args args) + : id_(args.id), name_(args.name), output_(args.output.get()) {} + + // Adds an input. + void add_input(std::shared_ptr node) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.push_back(node); + } + + // Increments the aggregate processing time by the given delta. + void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + processing_time_ += delta; + } + + // Returns the unique node ID. + int64 id() const LOCKS_EXCLUDED(mu_) { return id_; } + + // Returns the node name. + const string& name() const { return name_; } + + // Returns the node inputs. + std::list> inputs() const LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return inputs_; + } + + // Returns the number of elements produced by the node. + int64 num_elements() const LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return num_elements_; + } + + // Returns the node output. + Node* output() const LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return output_; + } + + // Returns the aggregate processing time. + int64 processing_time() const LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return processing_time_; + } + + // Records that the node produced an element. + void record_element() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + num_elements_++; + } + + // Records that a node thread has started executing. + void record_start() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos(); + } + + // Records that a node thread has stopped executing. + void record_stop() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + std::thread::id tid = std::this_thread::get_id(); + auto start_time = gtl::FindOrNull(work_start_, tid); + if (start_time) { + processing_time_ += Env::Default()->NowNanos() - *start_time; + work_start_.erase(tid); + } else { + LOG(WARNING) + << "Encountered a stop event that was not preceded by a start event."; + } + } + + // Removes an input. + void remove_input(std::shared_ptr input) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.remove(input); + } + + // Collects tunable parameters in the subtree rooted in this node. + void CollectTunableParameters( + std::vector>* parameters) LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + for (auto& pair : parameters_) { + if (pair.second->state->tunable) { + parameters->push_back(pair.second); + } + } + for (auto& input : inputs_) { + input->CollectTunableParameters(parameters); + } + } + + // Returns the per-element output time for this node. + int64 OutputTime(std::vector* input_times) const LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return OutputTimeLocked(input_times); + } + + // Returns the per-element processing time spent in the subtree rooted in + // this node. + int64 ProcessingTime() const LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return ProcessingTimeLocked(); + } + + // Returns a copy of this node, making a deep copy of its inputs and a + // shallow copy of its tunable parameters. + // + // The purpose for this method is to allow the model optimization logic to + // operate over immutable state while allowing concurrent model updates. + std::shared_ptr Snapshot(std::shared_ptr output) + LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + std::shared_ptr result = Clone(output); + result->processing_time_ = processing_time_; + result->num_elements_ = num_elements_; + result->parameters_ = parameters_; + for (auto& input : inputs_) { + result->add_input(input->Snapshot(result)); + } + return result; + } + + protected: + // Creates a clone of this node. + virtual std::shared_ptr Clone(std::shared_ptr output) const + SHARED_LOCKS_REQUIRED(mu_) = 0; + + // Returns the per-element processing time spent in this node. + int64 NanosPerElementLocked() const SHARED_LOCKS_REQUIRED(mu_) { + if (num_elements_ == 0) { + return 0; + } + return static_cast(static_cast(processing_time_) / + static_cast(num_elements_)); + } + + // Returns the sum of per-element output time for the inputs of this node. + int64 OutputTimeForInputs(std::vector* input_times) const + SHARED_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto& input : inputs_) { + sum += input->OutputTime(input_times); + } + return sum; + } + + // Returns the per-element output time for this node. + virtual int64 OutputTimeLocked(std::vector* input_times) const + SHARED_LOCKS_REQUIRED(mu_) = 0; + + // Returns the sum of per-element processing time for the inputs of this node. + // + // TODO(jsimsa): use processing time history as a prior for future inputs + int64 ProcessingTimeForInputs() const SHARED_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto& input : inputs_) { + sum += input->ProcessingTime(); + } + return sum; + } + + // Returns the per-element processing time spent in the subtree rooted in + // this node. + virtual int64 ProcessingTimeLocked() const SHARED_LOCKS_REQUIRED(mu_) = 0; + + mutable mutex mu_; + const int64 id_; + const string name_; + int64 processing_time_ GUARDED_BY(mu_) = 0; + int64 num_elements_ GUARDED_BY(mu_) = 0; + std::map work_start_ GUARDED_BY(mu_); + std::map> parameters_ GUARDED_BY(mu_); + std::list> inputs_ GUARDED_BY(mu_); + + // The reference to the output node is not owned so that that deletion of a + // node results in recursive deletion of the subtree rooted in the node. + Node* output_ GUARDED_BY(mu_); }; +// InterleaveMany is used to model datasets whose inputs are used to create +// datasets whose elements are then interleaved. +std::shared_ptr MakeInterleaveManyNode(Node::Args args); + +// AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany +// nodes. +std::shared_ptr MakeAsyncInterleaveManyNode( + Node::Args args, std::vector> parameters); + +// KnownMany nodes model datasets that synchronously consume known number of +// input element per output element. +std::shared_ptr MakeKnownRatioNode(Node::Args args, double ratio); + +// AsyncKnownRatio nodes are the asynchronous version of KnownRate nodes. +std::shared_ptr MakeAsyncKnownRatioNode( + Node::Args args, double ratio, + std::vector> parameters); + +// Source nodes represent data sources. +std::shared_ptr MakeSourceNode(Node::Args args); + +// UnknownMany nodes represent datasets that synchronously consume an +// unknown number of input elements per output. +// +// Unlike KnownRatio nodes which expect the ratio between inputs and outputs is +// specified as a parameter, UnknownRatio estimates the ratio empirically. +std::shared_ptr MakeUnknownRatioNode(Node::Args args); + +// Unknown nodes represent datasets for which we do not have a model. It acts +// as pass-through between inputs and output. +std::shared_ptr MakeUnknownNode(Node::Args args); + // Abstract representation of a TensorFlow input pipeline that can be used // for collecting runtime information and optimizing performance. It collects // runtime information about execution of the input pipeline that is used to @@ -60,24 +333,13 @@ class Model { public: Model() = default; - // Adds a constant parameter for the given node. - void AddConstantParameter(const string& node_name, - const string& parameter_name, int64 value) - LOCKS_EXCLUDED(mu_); - - // Adds a node with the given name and given output (identified by name). - void AddNode(const string& name, const string& output_name) - LOCKS_EXCLUDED(mu_); + // Adds a node with the given name and given output. + void AddNode(Node::Factory factory, const string& name, + const string& output_name) LOCKS_EXCLUDED(mu_); // Increments the processing time for the given node.. void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_); - // Adds a tunable parameter for the given node. - void AddTunableParameter(const string& node_name, - const string& parameter_name, - std::shared_ptr value, int64 min, - int64 max) LOCKS_EXCLUDED(mu_); - // Runs optimization. void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_); @@ -96,305 +358,8 @@ class Model { void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_); private: - // Abstract representation of a TensorFlow input pipeline node. It collects - // information about inputs to this node, processing time spent executing the - // node logic, number of elements produced by the node, various other - // information (e.g. batch size or execution parallelism). - // - // Developers of tf.data transformations are not expected to interact with - // this class directly. Boiler plate code for creating the abstract - // representation of the input pipeline and collecting common information has - // been added to the implementation of `DatasetBase` and `DatasetBaseIterator` - // respectively. - // - // In addition, `DatasetBaseIterator` provides wrappers that can be used for - // transformation-specific information collection. The `SetMetadata` wrapper - // can be used to pass arbitrary metadata to the modeling framework, while the - // `StartWork` and `StopWork` wrappers should be used to correctly account for - // processing time of multi-threaded transformation that yield the CPU; such - // transformations should invoke `StartWork()` when a transformation thread - // starts executing (e.g. when created or woken up) and `StopWork()` when a - // transformation thread stops executing (e.g. when returning or waiting). - // - // TODO(jsimsa): Create an API to capture the abstract semantics of each - // tf.data transformation and replace switch-case blocks with inheritance. - class Node { - public: - // Represents a tunable parameter. - struct Tunable { - Tunable(std::shared_ptr state, int64 min, int64 max) - : value(state->value), min(min), max(max), state(std::move(state)) {} - - // Identifies the model value of the parameter. This can be different from - // the actual value (e.g. during optimization search). - int64 value; - - // Identifies the minimum value of the parameter. - int64 min; - - // Identifies the maximum value of the parameter. - int64 max; - - // Shared state of the parameter. - std::shared_ptr state; - }; - - Node(int64 id, const string& name, std::shared_ptr output) - : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {} - - // Adds a constant parameter. - void add_constant_param(const string& name, int64 value) - LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - constant_params_[name] = value; - } - - // Adds an input. - void add_input(std::shared_ptr node) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - inputs_.push_back(node); - } - - // Increments the aggregate processing time by the given delta. - void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - processing_time_ += delta; - } - - // Adds a tunable parameter. - void add_tunable_param(const string& name, - std::shared_ptr state, int64 min, - int64 max) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - tunable_params_[name] = - std::make_shared(std::move(state), min, max); - } - - // Returns the unique node ID. - int64 id() LOCKS_EXCLUDED(mu_) { return id_; } - - // Returns the node inputs. - std::list> inputs() LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return inputs_; - } - - // Returns the node name. - const string& name() LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return name_; - } - - // Returns the number of elements produced by the node. - int64 num_elements() LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return num_elements_; - } - - // Returns the node output. - std::shared_ptr output() LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return output_; - } - - // Records that the node produced an element. - void record_element() LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - num_elements_++; - } - - // Records that a node thread has started executing. - void record_start() LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos(); - } - - // Records that a node thread has stopped executing. - void record_stop() LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - std::thread::id tid = std::this_thread::get_id(); - auto start_time = gtl::FindOrNull(work_start_, tid); - DCHECK(start_time) - << "Encountered a stop event that was not preceded by a start event."; - if (start_time) { - processing_time_ += Env::Default()->NowNanos() - *start_time; - work_start_.erase(tid); - } - } - - // Removes an input. - void remove_input(std::shared_ptr input) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - inputs_.remove(input); - } - - // Set the node output. - void set_output(std::shared_ptr output) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - output_ = output; - } - - // Collects tunable parameters in the subtree rooted in this node. - void CollectTunables(std::vector>* tunables) - LOCKS_EXCLUDED(mu_); - - // Returns the per-element output time for this node. - int64 OutputTime(std::vector* input_times) LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return OutputTimeLocked(input_times); - } - - // Returns the per-element processing time spent in the subtree rooted in - // this node. - int64 ProcessingTime() LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return ProcessingTimeLocked(); - } - - // Returns a copy of this node, making a deep copy of its inputs and a - // shallow copy of its tunable parameters. - // - // The purpose for this method is to allow the model optimization logic to - // operate over immutable state while allowing concurrent model updates. - std::shared_ptr Snapshot(std::shared_ptr output) - LOCKS_EXCLUDED(mu_); - - private: - enum class Type { - BATCH = 0, - CACHE, - CONCATENATE, - FILTER, - FLAT_MAP, - INTERLEAVE, - MAP, - MAP_AND_BATCH, - PADDED_BATCH, - PARALLEL_INTERLEAVE, - PARALLEL_INTERLEAVE_V2, - PARALLEL_MAP, - PREFETCH, - REPEAT, - SHUFFLE, - SKIP, - TAKE, - ZIP, - UNKNOWN, - }; - - // Gets a value of the given parameter (tunable or constant). - int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_); - - // Returns the per-element processing time spent in this node. - int64 NanosPerElement() LOCKS_EXCLUDED(mu_) { - tf_shared_lock l(mu_); - return NanosPerElementLocked(); - } - - int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) { - if (num_elements_ == 0) { - return 0; - } - return (int64)((double)processing_time_ / (double)num_elements_); - } - - int64 OutputTimeLocked(std::vector* input_times) - SHARED_LOCKS_REQUIRED(mu_); - - int64 OutputTimeForInputs(std::vector* input_times) - SHARED_LOCKS_REQUIRED(mu_) { - int64 sum = 0; - for (auto input : inputs_) { - sum += input->OutputTime(input_times); - } - return sum; - } - - int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_); - - // Returns the per-element processing time spent in the inputs of this node. - int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) { - int64 sum = 0; - for (auto input : inputs_) { - sum += input->ProcessingTime(); - } - return sum; - } - - Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) { - if (name_ == "Batch") { - return Type::BATCH; - } - if (str_util::EndsWith(name_, "Cache")) { - return Type::CACHE; - } - if (name_ == "Concatenate") { - return Type::CONCATENATE; - } - if (name_ == "Filter") { - return Type::FILTER; - } - if (name_ == "FlatMap") { - return Type::FLAT_MAP; - } - if (name_ == "Interleave") { - return Type::INTERLEAVE; - } - if (name_ == "Map") { - return Type::MAP; - } - if (name_ == "MapAndBatch" || name_ == "NumaMapAndBatch") { - return Type::MAP_AND_BATCH; - } - if (name_ == "PaddedBatch") { - return Type::PADDED_BATCH; - } - if (name_ == "ParallelInterleave") { - return Type::PARALLEL_INTERLEAVE; - } - if (name_ == "ParallelInterleaveV2") { - return Type::PARALLEL_INTERLEAVE_V2; - } - if (name_ == "ParallelMap") { - return Type::PARALLEL_MAP; - } - if (name_ == "Prefetch") { - return Type::PREFETCH; - } - if (str_util::EndsWith(name_, "Repeat")) { - return Type::REPEAT; - } - if (name_ == "Shuffle") { - return Type::SHUFFLE; - } - if (str_util::EndsWith(name_, "Skip")) { - return Type::SKIP; - } - if (str_util::EndsWith(name_, "Take")) { - return Type::TAKE; - } - if (name_ == "Zip") { - return Type::ZIP; - } - return Type::UNKNOWN; - } - - mutex mu_; - const int64 id_; - const string name_; - const Type type_; - int64 processing_time_ GUARDED_BY(mu_) = 0; - int64 num_elements_ GUARDED_BY(mu_) = 0; - std::map work_start_ GUARDED_BY(mu_); - std::map constant_params_ GUARDED_BY(mu_); - // Tunables are shared with the model during optimization. - std::map> tunable_params_ GUARDED_BY(mu_); - std::list> inputs_ GUARDED_BY(mu_); - std::shared_ptr output_ GUARDED_BY(mu_); - }; - - // Collects tunables in the tree rooted in the given node. - std::vector> CollectTunables( + // Collects tunable parameters in the tree rooted in the given node. + std::vector> CollectTunableParameters( std::shared_ptr node); // Collects the output time for the given node. diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc new file mode 100644 index 00000000000000..53e35f25b28cb3 --- /dev/null +++ b/tensorflow/core/framework/model_test.cc @@ -0,0 +1,336 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/model.h" +#include + +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace model { +namespace { + +class AsyncInterleaveManyTest + : public ::testing::TestWithParam> {}; + +TEST_P(AsyncInterleaveManyTest, Model) { + const int64 parallelism = std::get<0>(GetParam()); + const int64 input_time = std::get<1>(GetParam()); + std::shared_ptr async_interleave_many = + model::MakeAsyncInterleaveManyNode( + {0, "async_interleave_many", nullptr}, + {model::MakeParameter( + "parallelism", + std::make_shared(parallelism, nullptr, nullptr), 1, + parallelism)}); + std::shared_ptr meta_source = + model::MakeSourceNode({1, "meta_source", async_interleave_many}); + async_interleave_many->add_input(meta_source); + auto cleanup_meta = gtl::MakeCleanup([async_interleave_many, meta_source]() { + async_interleave_many->remove_input(meta_source); + }); + std::shared_ptr source1 = + model::MakeSourceNode({1, "source1", async_interleave_many}); + async_interleave_many->add_input(source1); + auto cleanup1 = gtl::MakeCleanup([async_interleave_many, source1]() { + async_interleave_many->remove_input(source1); + }); + std::shared_ptr source2 = + model::MakeSourceNode({2, "source2", async_interleave_many}); + async_interleave_many->add_input(source2); + auto cleanup2 = gtl::MakeCleanup([async_interleave_many, source2]() { + async_interleave_many->remove_input(source2); + }); + std::vector input_times(1, input_time); + async_interleave_many->add_processing_time(100); + EXPECT_EQ(100, async_interleave_many->processing_time()); + EXPECT_EQ(0, async_interleave_many->ProcessingTime()); + EXPECT_EQ(0, async_interleave_many->OutputTime(&input_times)); + async_interleave_many->record_element(); + EXPECT_EQ(1, async_interleave_many->num_elements()); + EXPECT_EQ(100, async_interleave_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, 100 - input_time), + async_interleave_many->OutputTime(&input_times)); + source1->add_processing_time(200); + source2->add_processing_time(300); + EXPECT_EQ(100, async_interleave_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, 100 - input_time), + async_interleave_many->OutputTime(&input_times)); + source1->record_element(); + source2->record_element(); + EXPECT_EQ(100 + 250, async_interleave_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, 100 + 250 / parallelism - input_time), + async_interleave_many->OutputTime(&input_times)); + async_interleave_many->record_element(); + EXPECT_EQ(50 + 250, async_interleave_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, 50 + 250 / parallelism - input_time), + async_interleave_many->OutputTime(&input_times)); +} + +INSTANTIATE_TEST_CASE_P(Test, AsyncInterleaveManyTest, + ::testing::Combine(::testing::Values(1, 2), + ::testing::Values(0, 50, 100, 200))); + +class AsyncKnownRatioTest + : public ::testing::TestWithParam> {}; + +TEST_P(AsyncKnownRatioTest, Model) { + const int64 parallelism = std::get<0>(GetParam()); + const int64 input_time = std::get<1>(GetParam()); + const int64 num_inputs_per_output = std::get<2>(GetParam()); + std::shared_ptr async_known_many = model::MakeAsyncKnownRatioNode( + {0, "async_known_many", nullptr}, num_inputs_per_output, + {model::MakeParameter( + "parallelism", + std::make_shared(parallelism, nullptr, nullptr), 1, + parallelism)}); + std::shared_ptr source1 = + model::MakeSourceNode({1, "source1", async_known_many}); + async_known_many->add_input(source1); + std::shared_ptr source2 = + model::MakeSourceNode({2, "source2", async_known_many}); + async_known_many->add_input(source2); + std::vector input_times(1, input_time); + source1->add_processing_time(100); + EXPECT_EQ(0, async_known_many->ProcessingTime()); + EXPECT_EQ(0, async_known_many->OutputTime(&input_times)); + source2->add_processing_time(200); + EXPECT_EQ(0, async_known_many->ProcessingTime()); + EXPECT_EQ(0, async_known_many->OutputTime(&input_times)); + source1->record_element(); + EXPECT_EQ(num_inputs_per_output * 100, async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * 100 - input_time), + async_known_many->OutputTime(&input_times)); + source2->record_element(); + EXPECT_EQ(num_inputs_per_output * (100 + 200), + async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * (100 + 200) - input_time), + async_known_many->OutputTime(&input_times)); + source1->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 200), + async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * (50 + 200) - input_time), + async_known_many->OutputTime(&input_times)); + source2->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 100), + async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * (50 + 100) - input_time), + async_known_many->OutputTime(&input_times)); + async_known_many->add_processing_time(128); + EXPECT_EQ(num_inputs_per_output * (50 + 100), + async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * (50 + 100) - input_time), + async_known_many->OutputTime(&input_times)); + async_known_many->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 100) + 128, + async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * (50 + 100) + + 128 / parallelism - input_time), + async_known_many->OutputTime(&input_times)); + async_known_many->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 100) + 64, + async_known_many->ProcessingTime()); + EXPECT_EQ(std::max(0LL, num_inputs_per_output * (50 + 100) + + 64 / parallelism - input_time), + async_known_many->OutputTime(&input_times)); +} + +INSTANTIATE_TEST_CASE_P(Test, AsyncKnownRatioTest, + ::testing::Combine(::testing::Values(1, 2, 4, 8), + ::testing::Values(0, 50, 100, 200), + ::testing::Values(0, 1, 2, 4))); + +TEST(InterleaveManyTest, Model) { + std::shared_ptr interleave_many = + model::MakeInterleaveManyNode({0, "interleave_many", nullptr}); + std::shared_ptr meta_source = + model::MakeSourceNode({1, "meta_source", interleave_many}); + interleave_many->add_input(meta_source); + std::shared_ptr source1 = + model::MakeSourceNode({1, "source1", interleave_many}); + interleave_many->add_input(source1); + std::shared_ptr source2 = + model::MakeSourceNode({2, "source2", interleave_many}); + interleave_many->add_input(source2); + std::vector input_times(1, 0); + interleave_many->add_processing_time(100); + EXPECT_EQ(100, interleave_many->processing_time()); + EXPECT_EQ(0, interleave_many->ProcessingTime()); + EXPECT_EQ(0, interleave_many->OutputTime(&input_times)); + interleave_many->record_element(); + EXPECT_EQ(1, interleave_many->num_elements()); + EXPECT_EQ(100, interleave_many->ProcessingTime()); + EXPECT_EQ(100, interleave_many->OutputTime(&input_times)); + source1->add_processing_time(200); + source2->add_processing_time(300); + EXPECT_EQ(100, interleave_many->ProcessingTime()); + EXPECT_EQ(100, interleave_many->OutputTime(&input_times)); + source1->record_element(); + source2->record_element(); + EXPECT_EQ(350, interleave_many->ProcessingTime()); + EXPECT_EQ(350, interleave_many->OutputTime(&input_times)); + interleave_many->record_element(); + EXPECT_EQ(300, interleave_many->ProcessingTime()); + EXPECT_EQ(300, interleave_many->OutputTime(&input_times)); +} + +class KnownRatioTest : public ::testing::TestWithParam {}; + +TEST_P(KnownRatioTest, Model) { + const int64 num_inputs_per_output = GetParam(); + std::shared_ptr known_many = model::MakeKnownRatioNode( + {0, "known_many", nullptr}, num_inputs_per_output); + std::shared_ptr source1 = + model::MakeSourceNode({1, "source1", known_many}); + known_many->add_input(source1); + std::shared_ptr source2 = + model::MakeSourceNode({2, "source2", known_many}); + known_many->add_input(source2); + std::vector input_times(1, 0); + source1->add_processing_time(100); + EXPECT_EQ(0, known_many->ProcessingTime()); + EXPECT_EQ(0, known_many->OutputTime(&input_times)); + source2->add_processing_time(200); + EXPECT_EQ(0, known_many->ProcessingTime()); + EXPECT_EQ(0, known_many->OutputTime(&input_times)); + source1->record_element(); + EXPECT_EQ(num_inputs_per_output * 100, known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * 100, known_many->OutputTime(&input_times)); + source2->record_element(); + EXPECT_EQ(num_inputs_per_output * (100 + 200), known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * (100 + 200), + known_many->OutputTime(&input_times)); + source1->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 200), known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * (50 + 200), + known_many->OutputTime(&input_times)); + source2->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 100), known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * (50 + 100), + known_many->OutputTime(&input_times)); + known_many->add_processing_time(128); + EXPECT_EQ(num_inputs_per_output * (50 + 100), known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * (50 + 100), + known_many->OutputTime(&input_times)); + known_many->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 100) + 128, + known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * (50 + 100) + 128, + known_many->OutputTime(&input_times)); + known_many->record_element(); + EXPECT_EQ(num_inputs_per_output * (50 + 100) + 64, + known_many->ProcessingTime()); + EXPECT_EQ(num_inputs_per_output * (50 + 100) + 64, + known_many->OutputTime(&input_times)); +} + +INSTANTIATE_TEST_CASE_P(Test, KnownRatioTest, ::testing::Values(0, 1, 2, 4)); + +TEST(SourceTest, Model) { + std::shared_ptr source = model::MakeSourceNode({0, "source", nullptr}); + std::vector input_times(1, 0); + source->add_processing_time(100); + EXPECT_EQ(100, source->processing_time()); + EXPECT_EQ(0, source->ProcessingTime()); + EXPECT_EQ(0, source->OutputTime(&input_times)); + source->record_element(); + EXPECT_EQ(1, source->num_elements()); + EXPECT_EQ(100, source->ProcessingTime()); + EXPECT_EQ(100, source->OutputTime(&input_times)); + source->record_element(); + EXPECT_EQ(2, source->num_elements()); + EXPECT_EQ(50, source->ProcessingTime()); + EXPECT_EQ(50, source->OutputTime(&input_times)); +} + +TEST(UnknownRatioTest, Model) { + std::shared_ptr unknown_many = + model::MakeUnknownRatioNode({0, "unknown_many", nullptr}); + std::shared_ptr source1 = + model::MakeSourceNode({1, "source1", unknown_many}); + unknown_many->add_input(source1); + std::shared_ptr source2 = + model::MakeSourceNode({2, "source2", unknown_many}); + unknown_many->add_input(source2); + std::vector input_times(1, 0); + unknown_many->add_processing_time(100); + EXPECT_EQ(100, unknown_many->processing_time()); + EXPECT_EQ(0, unknown_many->ProcessingTime()); + EXPECT_EQ(0, unknown_many->OutputTime(&input_times)); + unknown_many->record_element(); + EXPECT_EQ(1, unknown_many->num_elements()); + EXPECT_EQ(100, unknown_many->ProcessingTime()); + EXPECT_EQ(100, unknown_many->OutputTime(&input_times)); + source1->add_processing_time(100); + source2->add_processing_time(200); + EXPECT_EQ(100, unknown_many->ProcessingTime()); + EXPECT_EQ(100, unknown_many->OutputTime(&input_times)); + source1->record_element(); + source2->record_element(); + EXPECT_EQ(400, unknown_many->ProcessingTime()); + EXPECT_EQ(400, unknown_many->OutputTime(&input_times)); + unknown_many->record_element(); + EXPECT_EQ(200, unknown_many->ProcessingTime()); + EXPECT_EQ(200, unknown_many->OutputTime(&input_times)); +} + +TEST(UnknownTest, Model) { + std::shared_ptr unknown = + model::MakeUnknownNode({0, "unknown", nullptr}); + std::shared_ptr source1 = + model::MakeSourceNode({1, "source1", unknown}); + unknown->add_input(source1); + std::shared_ptr source2 = + model::MakeSourceNode({2, "source2", unknown}); + unknown->add_input(source2); + std::vector input_times(1, 0); + source1->add_processing_time(100); + EXPECT_EQ(0, unknown->ProcessingTime()); + EXPECT_EQ(0, unknown->OutputTime(&input_times)); + source2->add_processing_time(100); + EXPECT_EQ(0, unknown->ProcessingTime()); + EXPECT_EQ(0, unknown->OutputTime(&input_times)); + source1->record_element(); + EXPECT_EQ(100, unknown->ProcessingTime()); + EXPECT_EQ(100, unknown->OutputTime(&input_times)); + source2->record_element(); + EXPECT_EQ(200, unknown->ProcessingTime()); + EXPECT_EQ(200, unknown->OutputTime(&input_times)); + source1->record_element(); + EXPECT_EQ(150, unknown->ProcessingTime()); + EXPECT_EQ(150, unknown->OutputTime(&input_times)); + source2->record_element(); + EXPECT_EQ(100, unknown->ProcessingTime()); + EXPECT_EQ(100, unknown->OutputTime(&input_times)); + // Unknown node processing time should not affect its ProcessingTime() or + // OutputTime(). + unknown->add_processing_time(100); + EXPECT_EQ(100, unknown->processing_time()); + EXPECT_EQ(100, unknown->ProcessingTime()); + EXPECT_EQ(100, unknown->OutputTime(&input_times)); + // Unknown node number of elements should not affect its ProcessingTime() or + // OutputTime(). + unknown->record_element(); + EXPECT_EQ(1, unknown->num_elements()); + EXPECT_EQ(100, unknown->ProcessingTime()); + EXPECT_EQ(100, unknown->OutputTime(&input_times)); +} + +} // namespace +} // namespace model +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 1eb12d3f9539ca..5f08c130871751 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" -#include +#include // NOLINT #include #include #include @@ -286,6 +286,13 @@ OpKernelContext::~OpKernelContext() { } } if (params_->record_tensor_accesses) referenced_tensors_.Destroy(); + if (params_->track_allocations && !wrapped_allocators_.empty()) { + LOG(WARNING) << "OpKernelContext is tracking allocations but they are not " + << "being consumed by the StepStatsCollector."; + for (auto& wrapped_alloator : wrapped_allocators_) { + wrapped_alloator.second->GetRecordsAndUnRef(); + } + } } Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 6c71e118c0244e..165115aab32b7a 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -982,9 +982,10 @@ class OpKernelContext { return params_->output_attr_array[index]; } - gtl::InlinedVector wrapped_allocators() const { + gtl::InlinedVector ConsumeWrappedAllocators() { mutex_lock lock(mu_); - gtl::InlinedVector retrieved = wrapped_allocators_; + gtl::InlinedVector retrieved; + retrieved.swap(wrapped_allocators_); return retrieved; } diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h index 8a240ee1e35fc4..8c69c870345a68 100644 --- a/tensorflow/core/framework/variant_tensor_data.h +++ b/tensorflow/core/framework/variant_tensor_data.h @@ -38,6 +38,8 @@ class VariantTensorDataProto; class VariantTensorData { public: VariantTensorData(); + // TODO(b/118823936): This silently returns if the proto is invalid. + // Consider calling FromProto explicitly instead. VariantTensorData(VariantTensorDataProto proto); ~VariantTensorData(); diff --git a/tensorflow/core/graph/edgeset.cc b/tensorflow/core/graph/edgeset.cc index 2e0c67146169d4..02315a3e27b9e8 100644 --- a/tensorflow/core/graph/edgeset.cc +++ b/tensorflow/core/graph/edgeset.cc @@ -37,7 +37,7 @@ std::pair EdgeSet::insert(value_type value) { } } // array is full. convert to set. - s = new std::set; + s = new gtl::FlatSet; for (int i = 0; i < kInline; i++) { s->insert(static_cast(ptrs_[i])); } diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h index 0a1ee5a666cbd0..2776c8491c2b3f 100644 --- a/tensorflow/core/graph/edgeset.h +++ b/tensorflow/core/graph/edgeset.h @@ -17,17 +17,18 @@ limitations under the License. #define TENSORFLOW_GRAPH_EDGESET_H_ #include -#include -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { class Edge; // An unordered set of edges. Uses very little memory for small sets. -// Unlike std::set, EdgeSet does NOT allow mutations during iteration. +// Unlike gtl::FlatSet, EdgeSet does NOT allow mutations during +// iteration. class EdgeSet { public: EdgeSet(); @@ -54,12 +55,15 @@ class EdgeSet { private: // Up to kInline elements are stored directly in ptrs_ (nullptr means none). // If ptrs_[0] == this then ptrs_[1] points to a set. - static const int kInline = 4; // Must be >= 2. + // kInline must be >= 2, and is chosen such that ptrs_ fills a 64 byte + // cacheline. + static constexpr int kInline = 64 / sizeof(const void*); const void* ptrs_[kInline]; - std::set* get_set() const { + gtl::FlatSet* get_set() const { if (ptrs_[0] == this) { - return static_cast*>(const_cast(ptrs_[1])); + return static_cast*>( + const_cast(ptrs_[1])); } else { return nullptr; } @@ -99,7 +103,7 @@ class EdgeSet::const_iterator { friend class EdgeSet; void const* const* array_iter_ = nullptr; - typename std::set::const_iterator tree_iter_; + typename gtl::FlatSet::const_iterator tree_iter_; #ifdef NDEBUG inline void Init(const EdgeSet* e) {} diff --git a/tensorflow/core/graph/edgeset_test.cc b/tensorflow/core/graph/edgeset_test.cc index b4cef8f336550f..c5d2d6c70f0266 100644 --- a/tensorflow/core/graph/edgeset_test.cc +++ b/tensorflow/core/graph/edgeset_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/graph/edgeset.h" +#include #include #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/test.h" @@ -22,30 +23,27 @@ limitations under the License. namespace tensorflow { class EdgeSetTest : public ::testing::Test { public: - EdgeSetTest() : edges_(nullptr), eset_(nullptr) {} - - ~EdgeSetTest() override { - delete eset_; - delete[] edges_; - } + EdgeSetTest() : edges_(nullptr) {} + ~EdgeSetTest() override { delete[] edges_; } void MakeEdgeSet(int n) { - delete eset_; - delete[] edges_; + if (edges_) { + delete[] edges_; + } edges_ = new Edge[n]; - eset_ = new EdgeSet; + eset_.clear(); model_.clear(); for (int i = 0; i < n; i++) { - eset_->insert(&edges_[i]); + eset_.insert(&edges_[i]); model_.insert(&edges_[i]); } } void CheckSame() { - EXPECT_EQ(model_.size(), eset_->size()); - EXPECT_EQ(model_.empty(), eset_->empty()); + EXPECT_EQ(model_.size(), eset_.size()); + EXPECT_EQ(model_.empty(), eset_.empty()); std::vector modelv(model_.begin(), model_.end()); - std::vector esetv(eset_->begin(), eset_->end()); + std::vector esetv(eset_.begin(), eset_.end()); std::sort(modelv.begin(), modelv.end()); std::sort(esetv.begin(), esetv.end()); EXPECT_EQ(modelv.size(), esetv.size()); @@ -54,26 +52,27 @@ class EdgeSetTest : public ::testing::Test { } } + static constexpr int kInline = 64 / sizeof(const void*); Edge nonexistent_; Edge* edges_; - EdgeSet* eset_; + EdgeSet eset_; std::set model_; }; namespace { TEST_F(EdgeSetTest, Ops) { - for (int n : {0, 1, 2, 3, 4, 10}) { + for (int n : {0, 1, 2, kInline + 1}) { MakeEdgeSet(n); CheckSame(); - EXPECT_EQ((n == 0), eset_->empty()); - EXPECT_EQ(n, eset_->size()); + EXPECT_EQ((n == 0), eset_.empty()); + EXPECT_EQ(n, eset_.size()); - eset_->clear(); + eset_.clear(); model_.clear(); CheckSame(); - eset_->insert(&edges_[0]); + eset_.insert(&edges_[0]); model_.insert(&edges_[0]); CheckSame(); } @@ -81,15 +80,14 @@ TEST_F(EdgeSetTest, Ops) { // Try insert/erase of existing elements at different positions. TEST_F(EdgeSetTest, Exists) { - for (int n : {0, 1, 2, 3, 4, 10}) { + for (int n : {0, 1, 2, kInline + 1}) { MakeEdgeSet(n); for (int pos = 0; pos < n; pos++) { - MakeEdgeSet(n); - auto p = eset_->insert(&edges_[pos]); + auto p = eset_.insert(&edges_[pos]); EXPECT_FALSE(p.second); EXPECT_EQ(&edges_[pos], *p.first); - EXPECT_EQ(1, eset_->erase(&edges_[pos])); + EXPECT_EQ(1, eset_.erase(&edges_[pos])); model_.erase(&edges_[pos]); CheckSame(); } @@ -98,10 +96,10 @@ TEST_F(EdgeSetTest, Exists) { // Try insert/erase of non-existent element. TEST_F(EdgeSetTest, DoesNotExist) { - for (int n : {0, 1, 2, 3, 4, 10}) { + for (int n : {0, 1, 2, kInline + 1}) { MakeEdgeSet(n); - EXPECT_EQ(0, eset_->erase(&nonexistent_)); - auto p = eset_->insert(&nonexistent_); + EXPECT_EQ(0, eset_.erase(&nonexistent_)); + auto p = eset_.insert(&nonexistent_); EXPECT_TRUE(p.second); EXPECT_EQ(&nonexistent_, *p.first); } diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 333c32567fc9b9..e7762fd4147dfb 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -799,5 +799,44 @@ BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16); +static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) { + testing::StopTiming(); + const GraphDef graph_def = CreateGraphDef(num_nodes, num_edges_per_node); + const auto registry = OpRegistry::Global(); + GraphConstructorOptions opts; + // Warmup step. + Graph graph(registry); + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + int64 sum = 0; + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + sum += graph_def.node_size(); + } + VLOG(1) << sum; + testing::StopTiming(); +} +BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 2); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 2); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 2); +BENCHMARK(BM_ToGraphDef)->ArgPair(10, 4); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 4); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 4); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 4); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 4); +BENCHMARK(BM_ToGraphDef)->ArgPair(10, 8); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 8); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 8); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 8); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 8); +BENCHMARK(BM_ToGraphDef)->ArgPair(10, 16); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 16); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16); +BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc index c1f93ce05ae99f..642298fa95d9be 100644 --- a/tensorflow/core/graph/optimizer_cse_test.cc +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -337,9 +337,13 @@ TEST_F(OptimizerCSETest, Constant_Dedup) { EXPECT_EQ(OriginalGraph(), "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);" "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); - // In theory, there are 2^4 possible correct output of CSE. In this - // test, it happens to eliminate the last 4 nodes. - EXPECT_EQ(DoCSE(), "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const)|"); + std::vector nodes = str_util::Split(DoCSE(), ";|"); + std::set node_set(nodes.begin(), nodes.end()); + // Expect exactly one of each type of node to be retained after CSE. + EXPECT_EQ(node_set.count("n/_0(Const)") + node_set.count("n/_7(Const)"), 1); + EXPECT_EQ(node_set.count("n/_1(Const)") + node_set.count("n/_6(Const)"), 1); + EXPECT_EQ(node_set.count("n/_2(Const)") + node_set.count("n/_5(Const)"), 1); + EXPECT_EQ(node_set.count("n/_3(Const)") + node_set.count("n/_4(Const)"), 1); } static void BM_CSE(int iters, int op_nodes) { diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h index 0ba39426184e2c..b0f621fa6c4abc 100644 --- a/tensorflow/core/graph/tensor_id.h +++ b/tensorflow/core/graph/tensor_id.h @@ -41,6 +41,9 @@ struct TensorId : public std::pair { TensorId() : Base() {} TensorId(const SafeTensorId& id); + const StringPiece node() const { return first; } + int index() const { return second; } + string ToString() const { if (second == Graph::kControlSlot) return strings::StrCat("^", first); return strings::StrCat(first, ":", second); @@ -68,6 +71,9 @@ struct SafeTensorId : public std::pair { SafeTensorId(const string& str, int idx) : Base(str, idx) {} SafeTensorId(const TensorId& id); + const string& node() const { return first; } + int index() const { return second; } + string ToString() const { if (second == Graph::kControlSlot) return strings::StrCat("^", first); return strings::StrCat(first, ":", second); diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 3bad29a2390839..7b03ec38bf5bb1 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -23,6 +23,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -67,8 +68,14 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":utils", + "//tensorflow/core:graph", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", ], ) @@ -82,6 +89,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -165,8 +174,10 @@ cc_library( ":graph_view", ":grappler_item", ":utils", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -179,6 +190,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], ) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 144d7f8ce6c784..5090e62b2ccfb0 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -42,9 +42,10 @@ cc_library( deps = [ ":utils", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:topological_sort", - "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index ae48b9f159f40e..270b75269c7942 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -29,8 +29,9 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/costs/utils.h" -#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" @@ -290,9 +291,9 @@ bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { // This really should be done in an external debugging tool void VerboseLogUnknownDimensionSources( const GraphDef& graph, - const std::map>& + const std::unordered_map>& input_properties_map, - const std::map>& + const std::unordered_map>& output_properties_map) { if (!VLOG_IS_ON(2)) { return; @@ -456,10 +457,10 @@ class SymbolicShapeRefiner { const GraphView& graph, const std::unordered_map>& fed_ports) : graph_(graph), - function_library_(OpRegistry::Global(), graph.GetGraph()->library()), + function_library_(OpRegistry::Global(), graph.graph()->library()), fed_ports_(fed_ports) { - graph_def_version_ = graph.GetGraph()->versions().producer(); - node_to_context_.reserve(graph.GetGraph()->node_size()); + graph_def_version_ = graph.graph()->versions().producer(); + node_to_context_.reserve(graph.graph()->node_size()); } const GraphView& graph() const { return graph_; } @@ -512,7 +513,7 @@ class SymbolicShapeRefiner { // Placeholder with Const) don't affect one in // fun_to_grappler_function_item_. GrapplerFunctionItem grappler_function_item = it->second; - GraphView gv(&grappler_function_item.graph); + MutableGraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. for (int i = 0; i < grappler_function_item.inputs().size(); ++i) { @@ -524,27 +525,26 @@ class SymbolicShapeRefiner { "supported."); } NodeDef* fun_node = gv.GetNode(fun_input.input_name); - const string& input = function_node->input(i); - const string& node_name = NodeName(input); + const TensorId input_tensor = ParseTensorName(function_node->input(i)); - if (IsControlInput(input)) { + if (IsControlInput(input_tensor)) { return errors::FailedPrecondition( "Function inputs should not contain control nodes."); } - NodeDef* input_node = graph_.GetNode(node_name); + const NodeDef* input_node = graph_.GetNode(input_tensor.node()); if (input_node == nullptr) { - return errors::FailedPrecondition(node_name, + return errors::FailedPrecondition(input_tensor.node(), " was not found in the graph."); } InferenceContext* input_inference_context = GetContext(input_node); if (input_inference_context == nullptr) { return errors::FailedPrecondition( - "Inference context has not been created for ", node_name); + "Inference context has not been created for ", input_tensor.node()); } - int output_port_num = NodePosition(input); + int output_port_num = input_tensor.index(); AttrValue attr_output_shape; TensorShapeProto proto; const auto& handle = input_inference_context->output(output_port_num); @@ -566,7 +566,7 @@ class SymbolicShapeRefiner { for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { const string& input = function_node->input(i); const string& node_name = NodeName(input); - NodeDef* input_node = graph_.GetNode(node_name); + const NodeDef* input_node = graph_.GetNode(node_name); if (IsConstant(*input_node)) { TF_CHECK_OK( ReplaceInputWithConst(*input_node, i, &grappler_function_item)); @@ -609,24 +609,22 @@ class SymbolicShapeRefiner { // It is guaranteed that output_tensors does not contain any control // inputs, so port_id >= 0. - string out_tensor = out_arg.output_tensors[0]; - int port_id; - string node_name = ParseNodeName(out_tensor, &port_id); + TensorId out_tensor = ParseTensorName(out_arg.output_tensors[0]); - const NodeDef* retnode = gv.GetNode(node_name); + const NodeDef* retnode = gv.GetNode(out_tensor.node()); if (retnode == nullptr) { return errors::FailedPrecondition( - "Unable to find return function_node ", node_name, " for ", + "Unable to find return function_node ", out_tensor.node(), " for ", function_node->name()); } auto output_properties = gp.GetOutputProperties(retnode->name()); - if (port_id >= output_properties.size()) { + if (out_tensor.index() >= output_properties.size()) { return errors::InvalidArgument( - out_tensor, " has invalid position ", port_id, + out_tensor.ToString(), " has invalid position ", out_tensor.index(), " (output_properties.size() = ", output_properties.size(), ")."); } - auto const& outprop = output_properties[port_id]; + auto const& outprop = output_properties[out_tensor.index()]; const TensorShapeProto& shape = outprop.shape(); ShapeHandle out; TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out)); @@ -1427,8 +1425,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner, continue; } ShapeHandle input = in->output(fanin.src.port_id); - CHECK_EQ(fanin.tgt.node, node); - c->SetInput(fanin.tgt.port_id, input); + CHECK_EQ(fanin.dst.node, node); + c->SetInput(fanin.dst.port_id, input); if (!out_initialized) { out_initialized = true; out = input; @@ -1653,13 +1651,12 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { std::unordered_map> fed_ports; if (!assume_valid_feeds) { for (const auto& feed : item_.feed) { - int port_index = 0; - string node_name = ParseNodeName(feed.first, &port_index); - fed_ports[node_name].insert(port_index); + SafeTensorId tensor_id = ParseTensorName(feed.first); + fed_ports[tensor_id.node()].insert(tensor_id.index()); } } - GraphView graph_view(const_cast(&item_.graph)); + GraphView graph_view(&item_.graph); // List the resources and the nodes using them. Also collect the Merge nodes, // fed nodes, and primary inputs. @@ -1711,10 +1708,10 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { for (const auto& resource : resources) { for (const NodeDef* src : resource.second.first) { resource_handles[src] = resource.first; - for (const NodeDef* tgt : resource.second.second) { + for (const NodeDef* dst : resource.second.second) { // Add control edges from enqueue to dequeue nodes to ensure they are // processed in their logical order. - extra_deps.emplace_back(src, tgt); + extra_deps.emplace_back(src, dst); } } } @@ -1923,12 +1920,12 @@ Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) { return Status::OK(); } -bool GraphProperties::HasInputProperties(const string& name) const { - return input_properties_.find(name) != input_properties_.end(); +bool GraphProperties::HasInputProperties(const string& node_name) const { + return input_properties_.find(node_name) != input_properties_.end(); } -bool GraphProperties::HasOutputProperties(const string& name) const { - return output_properties_.find(name) != output_properties_.end(); +bool GraphProperties::HasOutputProperties(const string& node_name) const { + return output_properties_.find(node_name) != output_properties_.end(); } const std::vector& diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 28fd7565ccf5ba..fbae1ca5b437c1 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -63,8 +63,8 @@ class GraphProperties { // values strictly less than -1 to encode symbolic dimensions: although we // don't know the actual value of the symbolic dimension, we know that all the // dimensions denoted by the same negative value are the equal. - bool HasInputProperties(const string& name) const; - bool HasOutputProperties(const string& name) const; + bool HasInputProperties(const string& node_name) const; + bool HasOutputProperties(const string& node_name) const; const std::vector& GetInputProperties( const string& node_name) const; const std::vector& GetOutputProperties( @@ -123,8 +123,10 @@ class GraphProperties { // Data members const GrapplerItem& item_; - std::map> input_properties_; - std::map> output_properties_; + std::unordered_map> + input_properties_; + std::unordered_map> + output_properties_; const std::vector missing_properties_; }; diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index de0a63fc4e39de..9b3958b6c175d8 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -63,217 +63,5 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { return OpPortIdToArgId(node, op.input_arg(), port_id); } -GraphView::GraphView(GraphDef* graph) : graph_(graph) { - for (int i = 0; i < graph_->node_size(); i++) { - auto node = graph_->mutable_node(i); - AddUniqueNodeOrDie(node); - } - - for (NodeDef& node : *graph_->mutable_node()) { - AddFanouts(&node); - } -} - -void GraphView::AddUniqueNodeOrDie(NodeDef* node) { - auto result = nodes_.emplace(node->name(), node); - // Check that the graph doesn't contain multiple nodes with the same name. - CHECK(result.second) << "Non unique node name detected: " << node->name(); -} - -void GraphView::AddFanouts(NodeDef* node) { - for (int i = 0; i < node->input_size(); ++i) { - OutputPort fanin; - const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); - fanin.node = nodes_[fanin_name]; - - InputPort input; - input.node = node; - if (fanin.port_id < 0) { - input.port_id = -1; - } else { - input.port_id = i; - num_regular_outputs_[fanin.node] = - std::max(num_regular_outputs_[fanin.node], fanin.port_id); - } - - fanouts_[fanin].insert(input); - } -} - -NodeDef* GraphView::GetNode(const string& node_name) const { - auto it = nodes_.find(node_name); - if (it == nodes_.end()) { - return nullptr; - } - return it->second; -} - -GraphView::InputPort GraphView::GetInputPort(const string& node_name, - int port_id) const { - InputPort result; - result.node = GetNode(node_name); - // TODO(bsteiner): verify that the node has at least port_id input ports - result.port_id = port_id; - return result; -} - -GraphView::OutputPort GraphView::GetOutputPort(const string& node_name, - int port_id) const { - OutputPort result; - result.node = GetNode(node_name); - // TODO(bsteiner): verify that the node has at least port_id output ports - result.port_id = port_id; - return result; -} - -const std::unordered_set& -GraphView::GetFanout(const GraphView::OutputPort& port) const { - auto it = fanouts_.find(port); - if (it == fanouts_.end()) { - return empty_set_; - } - return it->second; -} - -std::unordered_set -GraphView::GetFanin(const GraphView::InputPort& port) const { - std::unordered_set result; - if (port.port_id >= 0) { - result.insert(GetRegularFanin(port)); - } else { - for (int i = port.node->input_size() - 1; i >= 0; --i) { - OutputPort fanin; - string fanin_name = ParseNodeName(port.node->input(i), &fanin.port_id); - if (fanin.port_id < 0) { - auto it = nodes_.find(fanin_name); - if (it != nodes_.end()) { - fanin.node = it->second; - result.insert(fanin); - } - } else { - break; - } - } - } - return result; -} - -const GraphView::OutputPort GraphView::GetRegularFanin( - const GraphView::InputPort& port) const { - CHECK_LE(0, port.port_id); - OutputPort fanin; - string fanin_name = - ParseNodeName(port.node->input(port.port_id), &fanin.port_id); - auto it = nodes_.find(fanin_name); - if (it == nodes_.end()) { - fanin.node = nullptr; - } else { - fanin.node = it->second; - } - return fanin; -} - -std::unordered_set -GraphView::GetFanouts(const NodeDef& node, - bool include_controlled_nodes) const { - std::unordered_set result; - OutputPort port; - port.node = const_cast(&node); - const int first_port_id = include_controlled_nodes ? -1 : 0; - auto it = num_regular_outputs_.find(&node); - const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1; - - for (int i = first_port_id; i <= last_port_id; ++i) { - port.port_id = i; - auto it = fanouts_.find(port); - if (it != fanouts_.end()) { - result.insert(it->second.begin(), it->second.end()); - } - } - return result; -} - -std::unordered_set -GraphView::GetFanins(const NodeDef& node, - bool include_controlling_nodes) const { - std::unordered_set result; - for (int i = 0; i < node.input_size(); ++i) { - OutputPort fanin; - string fanin_name = ParseNodeName(node.input(i), &fanin.port_id); - if (fanin.port_id < 0) { - if (!include_controlling_nodes) { - break; - } - } - auto it = nodes_.find(fanin_name); - if (it != nodes_.end()) { - fanin.node = it->second; - result.insert(fanin); - } - } - return result; -} - -int GraphView::NumFanins(const NodeDef& node, - bool include_controlling_nodes) const { - int count = 0; - for (const string& input : node.input()) { - if (!include_controlling_nodes && IsControlInput(input)) { - break; - } - count += 1; - } - return count; -} - -std::unordered_set -GraphView::GetFanoutEdges(const NodeDef& node, - bool include_controlled_edges) const { - std::unordered_set result; - OutputPort port; - port.node = const_cast(&node); - const int first_port_id = include_controlled_edges ? -1 : 0; - auto it = num_regular_outputs_.find(&node); - const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1; - - for (int i = first_port_id; i <= last_port_id; ++i) { - port.port_id = i; - auto it = fanouts_.find(port); - if (it != fanouts_.end()) { - Edge fanout; - fanout.src.node = const_cast(&node); - fanout.src.port_id = i; - for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { - fanout.tgt = *itr; - result.insert(fanout); - } - } - } - return result; -} - -std::unordered_set -GraphView::GetFaninEdges(const NodeDef& node, - bool include_controlling_edges) const { - std::unordered_set result; - for (int i = 0; i < node.input_size(); ++i) { - Edge fanin; - fanin.tgt.node = const_cast(&node); - fanin.tgt.port_id = i; - string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id); - if (fanin.src.port_id < 0) { - if (!include_controlling_edges) { - break; - } - } - auto it = nodes_.find(fanin_name); - if (it != nodes_.end()) { - fanin.src.node = it->second; - result.insert(fanin); - } - } - return result; -} - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 09c36a136834cf..77f4ec730a3a91 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -18,9 +18,16 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -36,114 +43,303 @@ namespace grappler { int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); -// A utility class to simplify the traversal of a GraphDef. -class GraphView { +namespace internal { + +// GraphViewInternal is a helper class to simplify graph traversal. It creates +// an immutable view of the nodes and edges represented by a GraphDef protocol +// buffer. +// +// There are two public classes implementing GraphViewInternal: +// +// - GraphView: constructed from the `const GraphDef` and doesn't allow +// to mutate underlying graph via input/output ports lookup functions (ports +// have const pointers to nodes). +// +// - MutableGraphView: constructed from the 'GraphDef` and allows to mutate +// the graph via input/output ports lookup functions (ports have non-const +// pointers to nodes), and also have couple additional functions to +// add/remove/replace nodes in the graph. +// +// --------------------------- !!! WARNING !!! --------------------------------- +// Removing nodes from the graph outside of MutableGraphView will +// lead to segfaults! Guaranteed by absl::string_view! +// ----------------------------------------------------------------------------- +// +template +class GraphViewInternal { public: struct Port { - Port() = default; - Port(NodeDef* n, int port) : node(n), port_id(port) {} - - // TODO(prazek): ports should keep the constness of GraphView. The only way - // to modify graph through the view should be using MutableGraphView. - NodeDef* node = nullptr; - int port_id = -1; + Port() : node(nullptr), port_id(0) {} + Port(NodeDefT* n, int port) : node(n), port_id(port) {} bool operator==(const Port& other) const { return node == other.node && port_id == other.port_id; } + + template + friend H AbslHashValue(H h, const Port& p) { + return H::combine(std::move(h), p.node, p.port_id); + } + + NodeDefT* node; + int port_id; }; + struct InputPort : public Port { - InputPort() = default; - InputPort(NodeDef* n, int port_id) : Port(n, port_id) {} - InputPort(const NodeDef* n, int port_id) - : Port(const_cast(n), port_id) {} - }; - struct OutputPort : public Port { - OutputPort() = default; - OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {} + using Port::Port; }; - struct HashPort { - std::size_t operator()(const Port& port) const { - return reinterpret_cast(port.node) + port.port_id; - } + struct OutputPort : public Port { + using Port::Port; }; struct Edge { - OutputPort src; - InputPort tgt; + Edge(OutputPort s, InputPort d) : src(s), dst(d) {} bool operator==(const Edge& other) const { - return src == other.src && tgt == other.tgt; + return src == other.src && dst == other.dst; } - }; - struct HashEdge { - std::size_t operator()(const Edge& edge) const { - return HashPort()(edge.src) + HashPort()(edge.tgt); + + template + friend H AbslHashValue(H h, const Edge& e) { + return H::combine(std::move(h), e.src, e.dst); } + + OutputPort src; + InputPort dst; }; - explicit GraphView(GraphDef* graph); - GraphDef* GetGraph() const { return graph_; } - NodeDef* GetNode(const string& node_name) const; + GraphDefT* graph() const { return graph_; } + + // Find a node by name or return `nullptr` if it's not in a graph view. + NodeDefT* GetNode(absl::string_view node_name) const { + return gtl::FindWithDefault(nodes_, node_name, nullptr); + } + // Get the specified input port. Note that the special '-1' port_id can be // used to access the controlling nodes (i.e. the nodes connected to node_name // through an incoming control dependency). - InputPort GetInputPort(const string& node_name, int port_id) const; + InputPort GetInputPort(absl::string_view node_name, int port_id) const { + return InputPort(GetNode(node_name), port_id); + } + // Get the specified output port. Note that the special '-1' port_id can be // used to access the controlled nodes (i.e. the nodes connected to node_name // through an outgoing control dependency). - OutputPort GetOutputPort(const string& node_name, int port_id) const; + OutputPort GetOutputPort(absl::string_view node_name, int port_id) const { + return OutputPort(GetNode(node_name), port_id); + } // Get the input (resp. output) port(s) in the immediate fanout (resp. fanin) // of an output (resp. input) port. - const std::unordered_set& GetFanout( - const OutputPort& port) const; - std::unordered_set GetFanin( - const InputPort& port) const; + const absl::flat_hash_set& GetFanout( + const OutputPort& port) const { + return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_); + } + + absl::flat_hash_set GetFanin(const InputPort& port) const { + if (port.port_id >= 0) return {GetRegularFanin(port)}; + + // Collect fanin for the control input. + absl::flat_hash_set result; + for (int i = port.node->input_size() - 1; i >= 0; --i) { + TensorId tensor_id = ParseTensorName(port.node->input(i)); + if (tensor_id.index() >= 0) break; // we reached regular inputs + + auto it = nodes_.find(tensor_id.node()); + if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); + } + return result; + } // Special case: regular (i.e. non-control) input ports can only have one // fanin. - const OutputPort GetRegularFanin(const InputPort& port) const; + const OutputPort GetRegularFanin(const InputPort& port) const { + DCHECK_GE(port.port_id, 0); + if (port.port_id < 0) return OutputPort(); + + TensorId tensor_id = ParseTensorName(port.node->input(port.port_id)); + return GetOutputPort(tensor_id.node(), tensor_id.index()); + } + + // Get all the input (resp. output) ports in the immediate fanout (resp + // fanin) of a node. Include the controlling nodes iff + // include_controlling_nodes is true. + absl::flat_hash_set GetFanouts( + const NodeDef& node, bool include_controlled_nodes) const { + absl::flat_hash_set result; + + OutputPort port; + port.node = const_cast(&node); + const int first_port_id = include_controlled_nodes ? -1 : 0; + const int last_port_id = + gtl::FindWithDefault(max_regular_output_port_, port.node, -1); + + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) { + result.insert(it->second.begin(), it->second.end()); + } + } + return result; + } + + absl::flat_hash_set GetFanins( + const NodeDef& node, bool include_controlling_nodes) const { + absl::flat_hash_set result; + for (int i = 0; i < node.input_size(); ++i) { + TensorId tensor_id = ParseTensorName(node.input(i)); + if (tensor_id.index() < 0 && !include_controlling_nodes) break; - // Get all the input (resp. output) ports in the immediate fanout (resp fanin) - // of a node. Include the controlling nodes iff include_controlling_nodes is - // true. - std::unordered_set GetFanouts( - const NodeDef& node, bool include_controlled_nodes) const; - std::unordered_set GetFanins( - const NodeDef& node, bool include_controlling_nodes) const; + auto it = nodes_.find(tensor_id.node()); + if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); + } + return result; + } // Get the number of ports in the immediate fanin of a node. Count the // controlling nodes iff include_controlling_nodes is true. - int NumFanins(const NodeDef& node, bool include_controlling_nodes) const; + int NumFanins(const NodeDef& node, bool include_controlling_nodes) const { + int count = 0; + for (const string& input : node.input()) { + if (!include_controlling_nodes && IsControlInput(input)) { + break; + } + count += 1; + } + return count; + } + + // Get the number of ports in the immediate fanout of a node. Count the + // controlling nodes iff include_controlling_nodes is true. + int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const { + int count = 0; + + OutputPort port; + port.node = const_cast(&node); + const int first_port_id = include_controlling_nodes ? -1 : 0; + const int last_port_id = + gtl::FindWithDefault(max_regular_output_port_, port.node, -1); + + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) count += it->second.size(); + } + + return count; + } + + // Get all the edges in the immediate fanout (resp fanin) of a node. + // Include the control edges iff include_controlling_edges is true. + absl::flat_hash_set GetFanoutEdges( + const NodeDef& node, bool include_controlled_edges) const { + absl::flat_hash_set result; + + OutputPort port; + port.node = const_cast(&node); + const int first_port_id = include_controlled_edges ? -1 : 0; + const int last_port_id = + gtl::FindWithDefault(max_regular_output_port_, &node, -1); - // Get all the edge in the immediate fanout (resp fanin) of a node. Include - // the control edges iff include_controlling_edges is true. - std::unordered_set GetFanoutEdges( - const NodeDef& node, bool include_controlled_edges) const; - std::unordered_set GetFaninEdges( - const NodeDef& node, bool include_controlling_edges) const; + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) { + for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { + result.emplace(/*src*/ OutputPort(const_cast(&node), i), + /*dst*/ *itr); + } + } + } + return result; + } + + absl::flat_hash_set GetFaninEdges( + const NodeDef& node, bool include_controlling_edges) const { + absl::flat_hash_set result; + for (int i = 0; i < node.input_size(); ++i) { + TensorId tensor_id = ParseTensorName(node.input(i)); + if (tensor_id.index() < 0 && !include_controlling_edges) break; + + auto it = nodes_.find(tensor_id.node()); + if (it != nodes_.end()) { + result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()), + /*dst*/ InputPort(const_cast(&node), i)); + } + } + return result; + } protected: - // Add a new `node` to the graph. - void AddUniqueNodeOrDie(NodeDef* node); - // Add fanout to every `node` input. - void AddFanouts(NodeDef* node); - std::unordered_map* MutableNodes() { return &nodes_; } - GraphDef* MutableGraph() { return graph_; } - - using FanoutsMapType = - std::unordered_map, - HashPort>; - FanoutsMapType* MutableFanouts() { return &fanouts_; } + explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} + + void AddUniqueNodeOrDie(NodeDefT* node) { + auto result = nodes_.emplace(node->name(), node); + // TODO(ezhulenev): Replace CHECK with factory method returning + // absl::StatusOr (when available). + CHECK(result.second) << "Non unique node name detected: " << node->name(); + } + + void AddFanouts(NodeDefT* node) { + for (int i = 0; i < node->input_size(); ++i) { + TensorId tensor_id = ParseTensorName(node->input(i)); + OutputPort output(nodes_[tensor_id.node()], tensor_id.index()); + + if (output.port_id < 0) { + fanouts_[output].emplace(node, -1); + } else { + max_regular_output_port_[output.node] = + std::max(max_regular_output_port_[output.node], output.port_id); + fanouts_[output].emplace(node, i); + } + } + } + + // Access to the mutable internal state for MutableGraphView. + absl::flat_hash_map& nodes() { return nodes_; } + + absl::flat_hash_map>& fanouts() { + return fanouts_; + } + + absl::flat_hash_map& max_regular_output_port() { + return max_regular_output_port_; + } private: - GraphDef* graph_; - std::unordered_map nodes_; - std::unordered_set empty_set_; - FanoutsMapType fanouts_; - std::unordered_map num_regular_outputs_; + GraphDefT* graph_; // must outlive the graph view + + // A mapping from the node name to the node itself. + absl::flat_hash_map nodes_; + + // A mapping from the output port to all inputs that read from it. + absl::flat_hash_map> fanouts_; + + // Keep a maximum index of tensor fetched from the node. It doesn't guarantee + // that all tensors in the [0, max_regular_output_port] range are actually + // fetched by other nodes. + absl::flat_hash_map max_regular_output_port_; + + // If the node has no fanouts at given output port (output tensor consumers) + // we return a reference to this set from `GetFanout` (we can't construct new + // empty set every time, because we need a non-dangling reference). + absl::flat_hash_set fanout_not_found_value_; +}; + +} // namespace internal + +// Immutable GraphView that keeps the constness of the GraphDef. If you need to +// mutate the graph or the nodes via the graph view lookup functions, see +// MutableGraphView. +class GraphView + : public internal::GraphViewInternal { + public: + explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) { + for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node); + for (const NodeDef& node : graph->node()) AddFanouts(&node); + } }; } // end namespace grappler diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index f90e2c8cfcd765..cbf859a4a99d7c 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/graph_view.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" #include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -158,19 +160,22 @@ TEST_F(GraphViewTest, BasicGraph) { const NodeDef* add_node = graph.GetNode("AddN"); EXPECT_NE(nullptr, add_node); - string fanouts; + + absl::flat_hash_set fanouts; + absl::flat_hash_set expected_fanouts = {"AddN_2:0", "AddN_3:0"}; for (const auto& fo : graph.GetFanouts(*add_node, false)) { - strings::StrAppend(&fanouts, - strings::StrCat(fo.node->name(), ":", fo.port_id, " ")); + fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id)); } - EXPECT_EQ("AddN_2:0 AddN_3:0 ", fanouts); + EXPECT_EQ(graph.NumFanouts(*add_node, false), 2); + EXPECT_EQ(fanouts, expected_fanouts); - string fanins; + absl::flat_hash_set fanins; + absl::flat_hash_set expected_fanins = {"Square_1:0", "Square:0"}; for (const auto& fi : graph.GetFanins(*add_node, false)) { - strings::StrAppend(&fanins, - strings::StrCat(fi.node->name(), ":", fi.port_id, " ")); + fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id)); } - EXPECT_EQ("Square_1:0 Square:0 ", fanins); + EXPECT_EQ(graph.NumFanins(*add_node, false), 2); + EXPECT_EQ(fanins, expected_fanins); } TEST_F(GraphViewTest, ControlDependencies) { diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index f0aff90c6c237c..1a4754153bca9b 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -14,13 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/mutable_graph_view.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/utils.h" namespace tensorflow { namespace grappler { +const absl::flat_hash_set& +MutableGraphView::GetFanout(const GraphView::OutputPort& port) const { + return GetFanout(MutableGraphView::OutputPort(const_cast(port.node), + port.port_id)); +} + +absl::flat_hash_set MutableGraphView::GetFanin( + const GraphView::InputPort& port) const { + return GetFanin(MutableGraphView::InputPort(const_cast(port.node), + port.port_id)); +} + +const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin( + const GraphView::InputPort& port) const { + return GetRegularFanin(MutableGraphView::InputPort( + const_cast(port.node), port.port_id)); +} + NodeDef* MutableGraphView::AddNode(NodeDef&& node) { - auto* node_in_graph = GetGraph()->add_node(); + auto* node_in_graph = graph()->add_node(); *node_in_graph = std::move(node); AddUniqueNodeOrDie(node_in_graph); @@ -29,54 +50,137 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) { return node_in_graph; } -NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node, - const int output_port_id) { - auto* node_in_graph = GetGraph()->add_node(); - *node_in_graph = std::move(node); +void MutableGraphView::UpdateFanouts(absl::string_view from_node, + absl::string_view to_node) { + NodeDef* from_node_ptr = GetNode(from_node); + NodeDef* to_node_ptr = GetNode(to_node); + if (from_node_ptr && to_node_ptr) { + UpdateFanouts(from_node_ptr, to_node_ptr); + } else if (!from_node_ptr) { + LOG(WARNING) << absl::Substitute( + "Can't update fanouts from '$0' to '$1', from node was not found.", + from_node, to_node); + } else { + LOG(WARNING) << absl::Substitute( + "Can't update fanouts from '$0' to '$1', to node was not found.", + from_node, to_node); + } +} - AddUniqueNodeOrDie(node_in_graph); +void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) { + VLOG(0) << absl::Substitute("Update fanouts from '$0' to '$1'.", + from_node->name(), to_node->name()); + + // Update internal state with the new output_port->input_port edge. + const auto add_edge = [this](const OutputPort& output_port, + const InputPort& input_port) { + fanouts()[output_port].insert(input_port); + }; + + // Remove invalidated edge from the internal state. + const auto remove_edge = [this](const OutputPort& output_port, + const InputPort& input_port) { + fanouts()[output_port].erase(input_port); + }; + + // First we update regular fanouts. For the regular fanouts + // `input_port:port_id` is the input index in NodeDef. + + auto regular_edges = + GetFanoutEdges(*from_node, /*include_controlled_edges=*/false); + + // Maximum index of the `from_node` output tensor that is still used as an + // input to some other node. + int keep_max_regular_output_port = -1; + + for (const Edge& edge : regular_edges) { + const OutputPort output_port = edge.src; + const InputPort input_port = edge.dst; + + // If the `to_node` reads from the `from_node`, skip this edge (see + // AddAndUpdateFanoutsWithoutSelfLoops test for an example). + if (input_port.node == to_node) { + keep_max_regular_output_port = + std::max(keep_max_regular_output_port, input_port.port_id); + continue; + } + + // Update input at destination node. + input_port.node->set_input( + input_port.port_id, + output_port.port_id == 0 + ? to_node->name() + : absl::StrCat(to_node->name(), ":", output_port.port_id)); + + // Remove old edge between the `from_node` and the fanout node. + remove_edge(output_port, input_port); + // Add an edge between the `to_node` and new fanout node. + add_edge(OutputPort(to_node, output_port.port_id), input_port); + } - // replace input for the output nodes of `input_node` with `node` - ReplaceInput(input_node, *node_in_graph, output_port_id); + // For the control fanouts we do not know the input index in a NodeDef, + // so we have to traverse all control inputs. + + auto control_fanouts = + GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot)); + if (control_fanouts.empty()) return; + + const string from_control_input = absl::StrCat("^", from_node->name()); + const string to_control_input = absl::StrCat("^", to_node->name()); + + for (const InputPort& control_port : control_fanouts) { + // Node can't be control dependency of itself. + if (control_port.node == to_node) continue; + + // Find and update input corresponding to control dependency. + NodeDef* node = control_port.node; + for (int i = node->input_size() - 1; i >= 0; --i) { + const string& input = node->input(i); + if (!IsControlInput(input)) break; // we reached regular inputs + if (input == from_control_input) { + node->set_input(i, to_control_input); + } + } + + // Remove old edge between the `from_node` and the fanout node. + remove_edge(OutputPort(from_node, Graph::kControlSlot), control_port); + // Add an edge between the `to_node` and new fanout node. + add_edge(OutputPort(to_node, Graph::kControlSlot), control_port); + } - AddFanouts(node_in_graph); - return node_in_graph; -} + // Because we update all regular fanouts of `from_node`, we can just copy + // the value `num_regular_outputs`. + max_regular_output_port()[to_node] = max_regular_output_port()[from_node]; -void MutableGraphView::ReplaceInput(const NodeDef& old_input, - const NodeDef& new_input, - const int output_port_id) { - GraphView::OutputPort output_port = - GetOutputPort(old_input.name(), output_port_id); - auto fanout = GetFanout(output_port); - for (auto& input_port : fanout) { - input_port.node->set_input(input_port.port_id, new_input.name()); - AddFanouts(input_port.node); + // Check if all fanouts were updated to read from the `to_node`. + if (keep_max_regular_output_port >= 0) { + max_regular_output_port()[from_node] = keep_max_regular_output_port; + } else { + max_regular_output_port().erase(from_node); } } void MutableGraphView::DeleteNodes(const std::set& nodes_to_delete) { for (const string& node_name_to_delete : nodes_to_delete) - RemoveFanouts(MutableNodes()->at(node_name_to_delete)); + RemoveFanouts(nodes().at(node_name_to_delete)); for (const string& node_name_to_delete : nodes_to_delete) - MutableNodes()->erase(node_name_to_delete); - EraseNodesFromGraph(nodes_to_delete, GetGraph()); + nodes().erase(node_name_to_delete); + EraseNodesFromGraph(nodes_to_delete, graph()); } -void MutableGraphView::RemoveFanouts(NodeDef* node) { - for (int i = 0; i < node->input_size(); ++i) { - OutputPort fanin; - string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); - fanin.node = (*MutableNodes())[fanin_name]; +void MutableGraphView::RemoveFanouts(NodeDef* deleted_node) { + for (int i = 0; i < deleted_node->input_size(); ++i) { + TensorId tensor_id = ParseTensorName(deleted_node->input(i)); + OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index()); InputPort input; - input.node = node; - if (fanin.port_id < 0) - input.port_id = -1; + input.node = deleted_node; + if (tensor_id.index() < 0) + input.port_id = Graph::kControlSlot; else input.port_id = i; - (*MutableFanouts())[fanin].erase(input); + fanouts()[fanin].erase(input); } } diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h index 971e5503d4ce90..355dd6c491763e 100644 --- a/tensorflow/core/grappler/mutable_graph_view.h +++ b/tensorflow/core/grappler/mutable_graph_view.h @@ -24,37 +24,64 @@ namespace grappler { // A utility class to simplify the traversal of a GraphDef that, unlike // GraphView, supports updating the graph. Note that you should not modify the // graph separately, because the view will get out of sync. -class MutableGraphView : public GraphView { + +class MutableGraphView : public internal::GraphViewInternal { public: - using GraphView::GraphView; + explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) { + for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node); + for (NodeDef& node : *graph->mutable_node()) AddFanouts(&node); + } - GraphDef* GetGraph() { return MutableGraph(); } + // Lookup fanouts/fanins using immutable ports. + using GraphViewInternal::GetFanout; + const absl::flat_hash_set& GetFanout( + const GraphView::OutputPort& port) const; - // Adds a new node to graph and updates the view. - NodeDef* AddNode(NodeDef&& node); + using GraphViewInternal::GetFanin; + absl::flat_hash_set GetFanin( + const GraphView::InputPort& port) const; - // Inserts a new node to the graph after `input` node and updates the view. - // This adds `node` to the graph and replaces the input for the output - // nodes of `input` with a port `output_port_id` with the new node. - NodeDef* InsertNode(const NodeDef& input, NodeDef&& node, - int output_port_id = 0); + using GraphViewInternal::GetRegularFanin; + const OutputPort GetRegularFanin(const GraphView::InputPort& port) const; - // Replaces the input for the output nodes of 'old_input' with a port - // `output_port_id` with 'new_input'. + // Adds a new node to graph and updates the view. Returns a pointer to the + // node in graph. + NodeDef* AddNode(NodeDef&& node); + + // Updates all fanouts (input ports fetching output tensors) from `from_node` + // to the `to_node`, including control dependencies. + // + // Example: We have 2 nodes that use `bar` node output tensors as inputs: + // 1. foo1(bar:0, bar:1, other:0, ^bar) + // 2. foo2(bar:1, other:1) // - // E.g: We have 2 nodes that use 'bar' node outputs as inputs: - // foo(bar:0, bar:1), foo2(other:0, bar:0) - // Calling ReplaceInput(bar, new, 0) changes every occurrence of bar:0 for - // new:0. Result: - // foo(new:0, bar:1), foo2(other:0, new:0) - void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input, - int output_port_id = 0); + // After calling ForwardOutputs(bar, new_bar): + // 1. foo1(new_bar:0, new_bar:1, other:0, ^new_bar) + // 2. foo2(new_bar:1, other:1) + void UpdateFanouts(absl::string_view from_node, absl::string_view to_node); // Deletes nodes from the graph. void DeleteNodes(const std::set& nodes_to_delete); private: - void RemoveFanouts(NodeDef* node); + // Updates all fanouts (input ports fetching output tensors) from `from_node` + // to the `to_node`, including control dependencies. + // + // Example: We have 2 nodes that use `bar` node output tensors as inputs: + // 1. foo1(bar:0, bar:1, other:0, ^bar) + // 2. foo2(bar:1, other:1) + // + // After calling ForwardOutputs(bar, new_bar): + // 1. foo1(new_bar:0, new_bar:1, other:0, ^new_bar) + // 2. foo2(new_bar:1, other:1) + // + // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the + // behavior is undefined. + void UpdateFanouts(NodeDef* from_node, NodeDef* to_node); + + // Remove fanouts of the deleted node from internal state (including control + // dependencies). + void RemoveFanouts(NodeDef* deleted_node); }; } // end namespace grappler diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc index 2536bec35ddcf7..c1b3f8c01cf3db 100644 --- a/tensorflow/core/grappler/mutable_graph_view_test.cc +++ b/tensorflow/core/grappler/mutable_graph_view_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/platform/test.h" @@ -23,103 +24,122 @@ namespace tensorflow { namespace grappler { namespace { -bool FindChildWithName(const MutableGraphView& graph, - const string& output_port_name, - const string& input_name) { - GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0); - auto fanout = graph.GetFanout(output_port); - for (auto& input_port : fanout) { - if (input_port.node->name() == input_name) return true; - } - return false; +using ::tensorflow::test::function::NDef; + +TEST(MutableGraphViewTest, AddAndUpdateFanouts) { + // Actual node.op() is not important in this test. + GraphDef graph_def = test::function::GDef( + {NDef("bar", "NotImportant", {}, {}), + NDef("other", "NotImportant", {}, {}), + NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}), + NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})}, + /* empty function library */ {}); + + MutableGraphView graph(&graph_def); + + NodeDef* new_bar = graph.AddNode(NDef("new_bar", "NotImportant", {}, {})); + NodeDef* bar = graph.GetNode("bar"); + + graph.UpdateFanouts(bar->name(), new_bar->name()); + + // Fanout nodes must have their inputs updated. + NodeDef* foo_1 = graph.GetNode("foo_1"); + ASSERT_NE(foo_1, nullptr); + ASSERT_EQ(foo_1->input_size(), 4); + EXPECT_EQ(foo_1->input(0), "new_bar"); + EXPECT_EQ(foo_1->input(1), "other"); + EXPECT_EQ(foo_1->input(2), "new_bar:1"); + EXPECT_EQ(foo_1->input(3), "^new_bar"); + + NodeDef* foo_2 = graph.GetNode("foo_2"); + ASSERT_NE(foo_2, nullptr); + ASSERT_EQ(foo_2->input_size(), 3); + EXPECT_EQ(foo_2->input(0), "other:1"); + EXPECT_EQ(foo_2->input(1), "new_bar:2"); + EXPECT_EQ(foo_2->input(2), "^new_bar"); + + // And fanouts mapping must be also updated for both nodes. + bool include_control_fanouts = true; + auto old_node_fanouts = graph.GetFanouts(*bar, include_control_fanouts); + auto new_node_fanouts = graph.GetFanouts(*new_bar, include_control_fanouts); + + EXPECT_TRUE(old_node_fanouts.empty()); + EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, 0)), 1); + EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, 2)), 1); + EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, -1)), 1); + EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_2, 1)), 1); + EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_2, -1)), 1); } -TrivialTestGraphInputYielder SimpleGraph() { - // This outputs simple graph like: - // x - // / \ - // Square Square_1 - // | \ / | - // | \/ | - // | /\ | - // | / \ | - // AddN AddN_1 - // \ / - // y - TrivialTestGraphInputYielder simple_graph(2, 2, 2, false, - {"/CPU:0", "/GPU:0"}); - return simple_graph; -} - -TEST(MutableGraphViewTest, AddAndReplaceInput) { - TrivialTestGraphInputYielder fake_input = SimpleGraph(); - GrapplerItem item; - CHECK(fake_input.NextItem(&item)); +TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) { + // Actual node.op() is not important in this test. + GraphDef graph_def = + test::function::GDef({NDef("bar", "NotImportant", {}, {}), + NDef("foo", "NotImportant", {"bar", "^bar"})}, + /* empty function library */ {}); - GraphDef new_graph = item.graph; - MutableGraphView graph(&new_graph); + MutableGraphView graph(&graph_def); - GraphView::InputPort input = graph.GetInputPort("AddN", 0); - EXPECT_EQ("AddN", input.node->name()); - EXPECT_EQ(0, input.port_id); - GraphView::OutputPort fanin = graph.GetRegularFanin(input); - EXPECT_EQ("Square", fanin.node->name()); - EXPECT_EQ(0, fanin.port_id); + // `new_bar` reads the output of an original `bar` node. + NodeDef* new_bar = graph.AddNode(NDef("new_bar", "NewBar", {"bar"}, {})); + NodeDef* bar = graph.GetNode("bar"); - EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node")); + graph.UpdateFanouts("bar", new_bar->name()); - NodeDef new_node = *input.node; - new_node.set_name("new_node"); + // Foo node must read from `new_bar`. + NodeDef* foo = graph.GetNode("foo"); + ASSERT_NE(foo, nullptr); + ASSERT_EQ(foo->input_size(), 2); + EXPECT_EQ(foo->input(0), "new_bar"); + EXPECT_EQ(foo->input(1), "^new_bar"); - EXPECT_EQ(graph.GetNode("new_node"), nullptr); - NodeDef* node_in_graph = graph.AddNode(std::move(new_node)); - EXPECT_NE(graph.GetNode("new_node"), nullptr); + // And the `new_bar` should read from the original `bar`. + ASSERT_EQ(new_bar->input_size(), 1); + ASSERT_EQ(new_bar->input(0), "bar"); - graph.ReplaceInput(*input.node, *node_in_graph); - EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node")); - EXPECT_TRUE(FindChildWithName(graph, "new_node", "y")); -} + // And fanouts mapping must be also updated for both nodes. + bool include_control_fanouts = true; + auto bar_fanouts = graph.GetFanouts(*bar, include_control_fanouts); + auto new_bar_fanouts = graph.GetFanouts(*new_bar, include_control_fanouts); -TEST(MutableGraphViewTest, InsertNodes) { - TrivialTestGraphInputYielder fake_input = SimpleGraph(); + EXPECT_EQ(bar_fanouts.size(), 1); + EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(new_bar, 0)), 1); - GrapplerItem item; - CHECK(fake_input.NextItem(&item)); + EXPECT_EQ(new_bar_fanouts.size(), 2); + EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, 0)), 1); + EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, -1)), 1); +} - GraphDef new_graph = item.graph; - MutableGraphView graph(&new_graph); +TEST(MutableGraphViewTest, DeleteNodes) { + // Actual node.op() is not important in this test. + GraphDef graph_def = test::function::GDef( + {NDef("bar", "NotImportant", {}, {}), + NDef("other", "NotImportant", {}, {}), + NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}), + NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})}, + /* empty function library */ {}); - GraphView::InputPort input = graph.GetInputPort("AddN", 0); + MutableGraphView graph(&graph_def); - NodeDef new_node = *input.node; - new_node.set_name("new_node"); - new_node.set_input(0, input.node->name()); + EXPECT_NE(graph.GetNode("foo_1"), nullptr); + graph.DeleteNodes({"foo_1"}); - EXPECT_EQ(graph.GetNode("new_node"), nullptr); - graph.InsertNode(*input.node, std::move(new_node)); - EXPECT_NE(graph.GetNode("new_node"), nullptr); - EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN")); - EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1")); - EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN")); - EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1")); - EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node")); - EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y")); - EXPECT_TRUE(FindChildWithName(graph, "new_node", "y")); -} + EXPECT_EQ(graph.GetNode("foo_1"), nullptr); -TEST(MutableGraphViewTest, DeleteNodes) { - // Outputs simple graph as described in first test. - TrivialTestGraphInputYielder fake_input = SimpleGraph(); - GrapplerItem item; - CHECK(fake_input.NextItem(&item)); + NodeDef* bar = graph.GetNode("bar"); + NodeDef* other = graph.GetNode("other"); + NodeDef* foo_2 = graph.GetNode("foo_2"); - GraphDef new_graph = item.graph; - MutableGraphView graph(&new_graph); + bool include_control_fanouts = true; + auto bar_fanouts = graph.GetFanouts(*bar, include_control_fanouts); + auto other_fanouts = graph.GetFanouts(*other, include_control_fanouts); - EXPECT_NE(graph.GetNode("AddN"), nullptr); - graph.DeleteNodes({"AddN"}); + EXPECT_EQ(bar_fanouts.size(), 2); + EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(foo_2, 1)), 1); + EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(foo_2, -1)), 1); - EXPECT_EQ(graph.GetNode("AddN"), nullptr); + EXPECT_EQ(other_fanouts.size(), 1); + EXPECT_EQ(other_fanouts.count(MutableGraphView::InputPort(foo_2, 0)), 1); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 127c1603ba4974..3a5b1334d3f42e 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -145,8 +145,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/utils:functions", @@ -212,6 +212,8 @@ cc_library( hdrs = ["graph_optimizer_stage.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", @@ -422,8 +424,8 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", @@ -625,12 +627,13 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:frame", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -663,8 +666,8 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index a5d618d4f737f5..cf294cd20bb3dc 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -235,18 +235,17 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage { // TODO(ezhulenev): move to GraphOptimizerStage? bool IsDrivenByControlDependency(const NodeDef& node) const { - return std::any_of(node.input().begin(), node.input().end(), - IsControlInput); + return std::any_of( + node.input().begin(), node.input().end(), + [](const string& input) { return IsControlInput(input); }); } // TODO(ezhulenev): move to GraphOptimizerStage? bool DrivesControlDependency(const NodeDef& node) const { - int position; for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { for (int i = 0; i < output->input_size(); ++i) { - auto input = output->input(i); - StringPiece name = ParseNodeNameAsStringPiece(input, &position); - if (name == node.name() && /*control input*/ position < 0) { + const TensorId tensor = ParseTensorName(output->input(i)); + if (tensor.node() == node.name() && tensor.index() < 0) { return true; } } @@ -1551,11 +1550,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { const auto& outputs = ctx().node_map->GetOutputs(node.name()); for (NodeDef* output : outputs) { if (IsControlInput(output->input(0))) continue; - int port; - const StringPiece node_name = - ParseNodeNameAsStringPiece(output->input(0), &port); - if (node_name == node.name()) { - tails->insert(ChainLink(output, port)); + TensorId tensor_id = ParseTensorName(output->input(0)); + if (tensor_id.node() == node.name()) { + tails->insert(ChainLink(output, tensor_id.index())); } else { // This output node has a non-control input other than the split node, // abort. @@ -1602,14 +1599,12 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { new_tails->insert(ChainLink(new_tail, link.port_origin)); } else { for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) { - int port; - const StringPiece node_name = - ParseNodeNameAsStringPiece(new_tail->input(0), &port); - if (node_name != tail->name()) { + const TensorId tensor = ParseTensorName(new_tail->input(0)); + if (tensor.node() != tail->name()) { return Status::OK(); } // Skip control outputs. - if (port >= 0) { + if (tensor.index() >= 0) { // Remember original port. new_tails->insert(ChainLink(new_tail, link.port_origin)); } @@ -1763,7 +1758,8 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); // Optimize only if divisor is a Sqrt whose output is not being consumed // elsewhere. - if (IsSqrt(*y) && (NumNonControlOutputs(*y, *ctx().node_map) == 1)) { + if (IsSqrt(*y) && !IsInPreserveSet(*y) && + (NumNonControlOutputs(*y, *ctx().node_map) == 1)) { // a / sqrt(b) = a * rsqrt(b) node->set_op("Mul"); y->set_op("Rsqrt"); @@ -1922,15 +1918,16 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer); new_consumer->set_input(0, new_producer->name()); + NodeDef* new_value_preserving = + producer_is_cast ? new_producer : new_consumer; + const DataType new_input_type = + producer_is_cast ? cast_src_type : cast_dst_type; // Update the input type of the value-preserving node. The input and // output types of the cast-like nodes remain the same. - if (producer_is_cast) { - // Op(Cast()) -> Cast(Op()) - TF_RETURN_IF_ERROR(SetInputType(cast_src_type, new_producer)); - } else { - // Cast(Op()) -> Op(Cast()) - TF_RETURN_IF_ERROR(SetInputType(cast_dst_type, new_consumer)); - } + TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving)); + // Make sure there is a kernel registered for the value preserving op + // with the new input type. + TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving)); ctx().node_map->AddOutput(new_producer->name(), new_consumer->name()); AddToOptimizationQueue(new_producer); @@ -3246,10 +3243,10 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const { h = Hash64Combine(Hash64(node.device()), h); for (const auto& input : node.input()) { - int pos; - const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos); - h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h); - h = Hash64CombineUnordered(std::hash()(pos), h); + const TensorId input_tensor = ParseTensorName(input); + h = Hash64CombineUnordered( + Hash64(input_tensor.node().data(), input_tensor.node().size()), h); + h = Hash64CombineUnordered(std::hash()(input_tensor.index()), h); } for (const auto& attr : node.attr()) { h = Hash64CombineUnordered(Hash64(attr.first), h); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 4c6dcee539c330..b6286c425e51b6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -1388,8 +1388,6 @@ TEST_F(ArithmeticOptimizerTest, ReorderS2DCast_ProducerIsCast) { ArithmeticOptimizer optimizer; OptimizeAndPrune(&optimizer, &item, &output); - LOG(INFO) << output.DebugString(); - const NodeDef* s2d_node = nullptr; for (const NodeDef& node : output.node()) { if (node.op() == "SpaceToDepth") { @@ -1862,7 +1860,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { // Conv2D(Transpose(Cast(I)), W*S) // => // Conv2D(Cast(Transpose(I)), W*S) - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0"); Output inputs = ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3})); @@ -1883,7 +1881,6 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { GraphDef output; ArithmeticOptimizer optimizer; // all optimization stages are on OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true); - LOG(INFO) << output.DebugString(); NodeMap node_map(&output); // Expected names for reordered cast and transpose. @@ -1918,7 +1915,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) { // This unit test exercises optimization of folding mul into conv for // multiple nodes in the graph. - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0"); GrapplerItem item; Output conv[2]; @@ -2649,6 +2646,48 @@ TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) { } } +TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output floats = ops::Const(s.WithOpName("floats"), + {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3}); + Output output0 = ops::Sqrt(s.WithOpName("output0"), floats); + Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3}); + Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f); + Output grad = ops::Div(s.WithOpName("grad"), mul1, output0); + + GrapplerItem item; + item.fetch = {"grad", "output0"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + ASSERT_EQ(2, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlySqrtDivToRsqrtMul(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + ASSERT_EQ(2, tensors.size()); + + for (int i = 0; i < tensors.size(); i++) { + EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements()); + test::ExpectTensorNear(tensors_expected[i], tensors[i], 1e-6); + } + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "grad") { + EXPECT_EQ("Div", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("mul1", node.input(0)); + EXPECT_EQ("output0", node.input(1)); + } else if (node.name() == "output0") { + EXPECT_EQ("Sqrt", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("floats", node.input(0)); + } + } +} + TEST_F(ArithmeticOptimizerTest, ConvertPow) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 8107d383f6b839..032f41c9d25a28 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -157,6 +157,16 @@ bool GetConcatAxis(const GraphProperties& properties, NodeDef* node, return true; } +bool HasTPUAttributes(const NodeDef& node) { + AttrSlice attrs(node); + for (auto attr : attrs) { + if (attr.first.find("_tpu_") != attr.first.npos) { + return true; + } + } + return false; +} + } // namespace ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, @@ -764,6 +774,13 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + // Don't fold nodes that contain TPU attributes. + // TODO(rmlarsen): We should be able to fold many of these nodes as long as we + // properly forward custom attributes, b/119051778. + if (HasTPUAttributes(node)) { + return false; + } + const OpDef* op_def = nullptr; Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { @@ -988,9 +1005,8 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, }); for (const auto& input : node.input()) { - int port = 0; - ParseNodeNameAsStringPiece(input, &port); - if (port < 0) { + const TensorId input_tensor = ParseTensorName(input); + if (input_tensor.index() < 0) { // Control dependency break; } @@ -1129,9 +1145,12 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph, std::vector const_nodes; TF_RETURN_IF_ERROR( EvaluateOneFoldable(*node, &const_nodes, result_too_large)); + VLOG(1) << "Folded node:\n" << node->DebugString(); + NodeDef* constant_output = nullptr; for (int i = 0; i < const_nodes.size(); i++) { NodeDef* const_node = &const_nodes[i]; + VLOG(1) << "Generated constant node:\n" << const_node->DebugString(); if (const_node->name().empty()) { // Dead output: we can't create a constant to encode its value, so we'll // just skip it. We'll preserve the edges that originate from that @@ -1596,15 +1615,19 @@ Status ConstantFolding::ReplaceOperationWithConstant( Status ConstantFolding::SimplifyGraph( bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties, - const absl::flat_hash_set& nodes_to_not_simplify) { + absl::flat_hash_set* nodes_to_not_simplify) { for (int i = 0; i < optimized_graph->node_size(); ++i) { + NodeDef* node = optimized_graph->mutable_node(i); // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and // generalize to only restrict certain simplifications. - if (nodes_to_not_simplify.find(optimized_graph->node(i).name()) == - nodes_to_not_simplify.end()) { - TF_RETURN_IF_ERROR(SimplifyNode(use_shape_info, - optimized_graph->mutable_node(i), - optimized_graph, properties)); + if (nodes_to_not_simplify->find(node->name()) == + nodes_to_not_simplify->end()) { + if (HasTPUAttributes(optimized_graph->node(i))) { + nodes_to_not_simplify->insert(node->name()); + continue; + } + TF_RETURN_IF_ERROR( + SimplifyNode(use_shape_info, node, optimized_graph, properties)); } } return Status::OK(); @@ -3043,7 +3066,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, TF_RETURN_IF_ERROR(FoldGraph(optimized_graph, &nodes_to_not_simplify)); node_map_.reset(new NodeMap(optimized_graph)); TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph, - &properties, nodes_to_not_simplify)); + &properties, &nodes_to_not_simplify)); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index c81d3067d50d1c..d1898cdb04938a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -100,10 +100,9 @@ class ConstantFolding : public GraphOptimizer { const GraphProperties& properties) const; bool IsSimplifiableReshape(const NodeDef& node, const GraphProperties& properties) const; - Status SimplifyGraph( - bool use_shape_info, GraphDef* optimized_graph, - GraphProperties* properties, - const absl::flat_hash_set& nodes_to_not_simplify); + Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph, + GraphProperties* properties, + absl::flat_hash_set* nodes_to_not_simplify); Status SimplifyNode(bool use_shape_info, NodeDef* node, GraphDef* optimized_graph, GraphProperties* properties); diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 244480a7516b61..89e95067b83d70 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -591,6 +591,7 @@ cc_library( deps = [ ":function_utils", ":graph_utils", + "@com_google_absl//absl/container:flat_hash_set", "//tensorflow/cc:ops", "@com_google_absl//absl/strings", "//tensorflow/core:core_cpu", diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc index 1ad495bbad023b..89b568ecf161cd 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -37,7 +37,7 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("fused_filter", graph->graph(), &fused_node); fused_node.set_op("FilterDataset"); @@ -109,7 +109,7 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode( *first_filter_node, *second_filter_node, *fused_predicate, &graph)); - graph.ReplaceInput(*second_filter_node, *fused_filter_node); + graph.UpdateFanouts(second_filter_node->name(), fused_filter_node->name()); // TODO(prazek): we should run some optimizations on the fused filter // functions, or make sure that optimization passes run after filter diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index b863a25dc5f699..90208c1fba6b08 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -72,7 +72,7 @@ NodeDef* AddScalarConstNodeHelper( MutableGraphView* graph) { NodeDef node; node.set_op(kConstOpName); - SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node); + SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node); (*node.mutable_attr())["dtype"].set_type(dtype); std::unique_ptr tensor = @@ -92,7 +92,7 @@ NodeDef* AddScalarConstNodeHelper( NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { NodeDef node; node.set_op("Placeholder"); - SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node); + SetUniqueGraphNodeName(node.op(), graph->graph(), &node); (*node.mutable_attr())["dtype"].set_type(dtype); TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape(); shape->set_unknown_rank(false); @@ -107,7 +107,7 @@ NodeDef* AddNode(StringPiece name, StringPiece op, if (!name.empty()) { node.set_name(string(name)); } else { - SetUniqueGraphNodeName(op, graph->GetGraph(), &node); + SetUniqueGraphNodeName(op, graph->graph(), &node); } node.set_op(string(op)); for (const string& input : inputs) { @@ -228,7 +228,7 @@ std::vector FindAllGraphNodesWithOp(const string& op, NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) { if (node.input_size() == 0) return nullptr; - GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); + MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); return graph.GetRegularFanin(input_port).node; } diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 4ab6d71532ce00..5c0f03dca8774d 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -41,7 +41,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* bool_node = AddScalarConstNode(true, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.graph())); EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true); } @@ -49,8 +49,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeDouble) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* double_node = AddScalarConstNode(3.14, &graph); - EXPECT_TRUE( - ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph())); EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14); } @@ -58,7 +57,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeFloat) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* float_node = AddScalarConstNode(3.14, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.graph())); EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14); } @@ -66,7 +65,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* int_node = AddScalarConstNode(42, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.graph())); EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42); } @@ -74,7 +73,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* int64_node = AddScalarConstNode(42, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.graph())); EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42); } @@ -82,8 +81,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* string_node = AddScalarConstNode("hello", &graph); - EXPECT_TRUE( - ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph())); EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); } @@ -106,13 +104,13 @@ TEST(GraphUtilsTest, Compare) { TEST(GraphUtilsTest, ContainsGraphNodeWithName) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); + EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph())); AddNode("A", "OpA", {}, {}, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph())); graph.DeleteNodes({"A"}); - EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); + EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph())); } TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { @@ -128,25 +126,25 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { TEST(GraphUtilsTest, ContainsNodeWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph())); AddNode("A", "OpA", {}, {}, &graph); - EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph())); + EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph())); graph.DeleteNodes({"A"}); - EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph())); } TEST(GraphUtilsTest, FindGraphNodeWithName) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1); AddNode("A", "OpA", {}, {}, &graph); - EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1); graph.DeleteNodes({"A"}); - EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1); } TEST(GraphUtilsTest, FindGraphFunctionWithName) { @@ -162,35 +160,35 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) { TEST(GraphUtilsTest, FindGraphNodeWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1); AddNode("A", "OpA", {}, {}, &graph); AddNode("B", "OpB", {"A"}, {}, &graph); AddNode("A2", "OpA", {"B"}, {}, &graph); - EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0); + EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0); graph.DeleteNodes({"B"}); - EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1); - EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1); + EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1); } TEST(GraphUtilsTest, FindAllGraphNodesWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1); AddNode("A", "OpA", {}, {}, &graph); AddNode("B", "OpB", {"A"}, {}, &graph); AddNode("A2", "OpA", {"B"}, {}, &graph); std::vector result_indices = - FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); + FindAllGraphNodesWithOp("OpA", *graph.graph()); EXPECT_EQ(result_indices.size(), 2); EXPECT_EQ(result_indices.at(0), 0); EXPECT_EQ(result_indices.at(1), 2); graph.DeleteNodes({"A2"}); std::vector result_indices_new = - FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); + FindAllGraphNodesWithOp("OpA", *graph.graph()); EXPECT_EQ(result_indices_new.size(), 1); EXPECT_EQ(result_indices_new.at(0), 0); } diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc index ce0b2db03963b2..5af9fbadf76bfd 100644 --- a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc @@ -39,7 +39,7 @@ NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node, const FunctionDef& stateless_function, MutableGraphView* graph) { NodeDef stateless_map; - graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(), &stateless_map); stateless_map.set_op("MapDataset"); @@ -68,7 +68,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, MutableGraphView* graph) { NodeDef random_dataset; random_dataset.set_op("RandomDataset"); - graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(), &random_dataset); const auto* seed = graph_utils::AddScalarConstNode( @@ -89,7 +89,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { NodeDef batch_dataset; batch_dataset.set_op("BatchDatasetV2"); - graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(), &batch_dataset); const auto* batch_size = graph_utils::AddScalarConstNode(2, graph); const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph); @@ -112,7 +112,7 @@ NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, MutableGraphView* graph) { NodeDef zip_node; - graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(), &zip_node); zip_node.set_op("ZipDataset"); @@ -266,7 +266,7 @@ Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item, const auto* stateless_map = graph.AddNode( MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph)); - graph.ReplaceInput(*map_node, *stateless_map); + graph.UpdateFanouts(map_node->name(), stateless_map->name()); // TODO(b/116285210): we could also remove map functions from library if // they are not used anymore. diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc index 9e382aeef9c257..16b2efb3ed3c25 100644 --- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc +++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc @@ -37,8 +37,7 @@ NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kInsertOpName); graph_utils::SetUniqueGraphNodeName( - strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(), - &new_node); + strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node); // Set the input of LatencyDataset node as `node` new_node.add_input(node.name()); @@ -81,7 +80,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item, // node corresponds to a `Dataset` op. continue; } - GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0); + MutableGraphView::OutputPort output_port = + graph.GetOutputPort(node.name(), 0); auto fanout = graph.GetFanout(output_port); if (fanout.size() > 1) { LOG(WARNING) << node.name() << " has fanout size " << fanout.size(); @@ -96,7 +96,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item, } } - graph.InsertNode(node, MakeLatencyNode(node, &graph)); + NodeDef* latency_node = graph.AddNode(MakeLatencyNode(node, &graph)); + graph.UpdateFanouts(node.name(), latency_node->name()); } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc b/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc index f9d7d027c12e86..e5de981822376d 100644 --- a/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc +++ b/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc @@ -29,7 +29,7 @@ namespace { NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) { NodeDef numa_aware_node = node; - graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->graph(), &numa_aware_node); numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset"); return numa_aware_node; @@ -47,7 +47,7 @@ Status MakeNumaAware::Optimize(Cluster* cluster, const GrapplerItem& item, if (node.op() != "MapAndBatchDatasetV2") continue; auto* numa_node = graph.AddNode(MakeNumaAwareNode(node, &graph)); - graph.ReplaceInput(node, *numa_node); + graph.UpdateFanouts(node.name(), numa_node->name()); nodes_to_delete.insert(node.name()); } graph.DeleteNodes(nodes_to_delete); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index e66766eb23bd53..800050b840326d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -36,8 +36,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kFusedOpName); - graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(), - &new_node); + graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->graph(), &new_node); // Set the `input` input argument. new_node.add_input(map_node.input(0)); @@ -114,7 +113,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item, auto* new_node = graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph)); - graph.ReplaceInput(batch_node, *new_node); + graph.UpdateFanouts(batch_node.name(), new_node->name()); // Mark the `Map` and `Batch` nodes for removal. nodes_to_delete.insert(map_node->name()); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc index b676246b318d5b..eed558de7eb42c 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc @@ -309,7 +309,7 @@ TEST(MapAndBatchFusionTest, NoChange) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); + EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output)); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index c4868eacbbf6d4..2b0a347ce62514 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -37,23 +37,29 @@ NodeDef MakeFusedNode(const NodeDef& map_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), - &fused_node); - fused_node.set_op("MapDataset"); - fused_node.add_input(map_node.input(0)); + graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node); + fused_node.set_op(map_node.op()); + + // Copy over inputs. + for (int i = 0; i < map_node.input_size(); ++i) { + fused_node.add_input(map_node.input(i)); + } auto attr = map_node.attr().at("f"); attr.mutable_func()->set_name(fused_function.signature().name()); (*fused_node.mutable_attr())["f"] = std::move(attr); - graph_utils::CopyAttribute("Targuments", map_node, &fused_node); - - for (auto key : {"output_shapes", "output_types"}) + // Required attrs. + for (auto key : {"Targuments", "output_shapes", "output_types"}) { graph_utils::CopyAttribute(key, map_node, &fused_node); + } - if (const auto* attr = - gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) - (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr; + // Optional attrs. + for (auto key : {"use_inter_op_parallelism", "sloppy"}) { + if (const auto* attr = gtl::FindOrNull(map_node.attr(), key)) { + graph_utils::CopyAttribute(key, map_node, &fused_node); + } + } // Add the predicate output attributes. (*fused_node.mutable_attr())["output_types"] @@ -72,8 +78,8 @@ NodeDef MakeFilterByLastComponentNode(const NodeDef& fused_map_node, const NodeDef& filter_node, MutableGraphView* graph) { NodeDef filter_by_component; - graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", - graph->GetGraph(), &filter_by_component); + graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph->graph(), + &filter_by_component); filter_by_component.set_op("FilterByLastComponentDataset"); filter_by_component.add_input(fused_map_node.name()); @@ -98,7 +104,9 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, FunctionLibraryDefinition function_library(OpRegistry::Global(), item.graph.library()); auto get_map_node = [](const NodeDef& node) -> const NodeDef* { - if (node.op() == "MapDataset") return &node; + if (node.op() == "MapDataset" || node.op() == "ParallelMapDataset") { + return &node; + } return nullptr; }; @@ -146,7 +154,7 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, const auto* filter_by_component = graph.AddNode( MakeFilterByLastComponentNode(*fused_maps, *filter_node, &graph)); - graph.ReplaceInput(*filter_node, *filter_by_component); + graph.UpdateFanouts(filter_node->name(), filter_by_component->name()); TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function)); // TODO(prazek): we could also remove functions from library if they are not diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc index 6e6da37d7c20de..c5a5e22aba6cd2 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc @@ -30,6 +30,7 @@ namespace grappler { namespace { using graph_tests_utils::MakeFilterNode; using graph_tests_utils::MakeMapNode; +using graph_tests_utils::MakeParallelMapNode; TEST(MapAndFilterFusionTest, FuseMapAndFilter) { using test::function::NDef; @@ -58,6 +59,41 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilter) { graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output)); } +TEST(MapAndFilterFusionTest, FuseParallelMapAndFilter) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + NDef("num_parallel_calls", "Const", {}, + {{"value", 3}, {"dtype", "DT_INT32"}}), + MakeParallelMapNode("map", "range", "num_parallel_calls", "XTimesTwo", + /*sloppy=*/false), + MakeFilterNode("filter", "map")}, + // FunctionLib + { + test::function::XTimesTwo(), + test::function::IsZero(), + }); + + MapAndFilterFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter", output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)) + << output.DebugString(); + auto& map_node = output.node( + graph_utils::FindGraphNodeWithOp("ParallelMapDataset", output)); + EXPECT_FALSE(map_node.attr().at("sloppy").b()) << map_node.DebugString(); + EXPECT_TRUE( + graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output)) + << output.DebugString(); +} + TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) { using test::function::NDef; GrapplerItem item; @@ -103,6 +139,56 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) { EXPECT_EQ(cache_node.input(0), filter_by_component.name()); } +TEST(MapAndFilterFusionTest, FuseParallelMapAndFilterWithExtraChild) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + NDef("num_parallel_calls", "Const", {}, + {{"value", 3}, {"dtype", "DT_INT32"}}), + MakeParallelMapNode("map", "range", "num_parallel_calls", "XTimesTwo", + /*sloppy=*/true), + MakeFilterNode("filter", "map"), + NDef("cache", "CacheDataset", {"filter", "filename"}, {})}, + // FunctionLib + { + test::function::XTimesTwo(), + test::function::IsZero(), + }); + + MapAndFilterFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter", output)); + ASSERT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + ASSERT_TRUE( + graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output)); + ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output)); + + int map_id = graph_utils::FindGraphNodeWithOp("ParallelMapDataset", output); + auto& map_node = output.node(map_id); + ASSERT_EQ(map_node.input_size(), 2); + EXPECT_EQ(map_node.input(0), "range"); + EXPECT_EQ(map_node.input(1), "num_parallel_calls"); + + int filter_by_component_id = + graph_utils::FindGraphNodeWithOp("FilterByLastComponentDataset", output); + auto& filter_by_component = output.node(filter_by_component_id); + ASSERT_EQ(filter_by_component.input_size(), 1); + EXPECT_EQ(filter_by_component.input(0), map_node.name()); + + int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output); + auto& cache_node = output.node(cache_id); + ASSERT_EQ(cache_node.input_size(), 2); + EXPECT_EQ(cache_node.input(0), filter_by_component.name()); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index bd943342e8009b..6ca0da27551bc7 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -39,8 +39,7 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), - &fused_node); + graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node); fused_node.set_op("MapDataset"); fused_node.add_input(parent_map_node.input(0)); @@ -124,7 +123,7 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item, const auto* fused_maps_node = graph.AddNode( MakeFusedNode(*parent_map_node, *map_node, *fused_function, &graph)); - graph.ReplaceInput(*map_node, *fused_maps_node); + graph.UpdateFanouts(map_node->name(), fused_maps_node->name()); // TODO(prazek): we should run some optimizations on the fused map // functions, or make sure that optimization passes run after map diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc index 782c9f48b74b94..8e49f908a77288 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -47,7 +47,7 @@ bool CanParallelize(const FunctionDef& function, NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) { NodeDef parallel_map = map_node; - graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("parallel_map", graph->graph(), ¶llel_map); parallel_map.set_op("ParallelMapDataset"); // TODO(b/114475558): We want to set `num_parallel_calls` to a special value, @@ -83,7 +83,7 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item, if (!CanParallelize(*function, function_library)) continue; auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph)); - graph.ReplaceInput(*map_node, *parallel_map); + graph.UpdateFanouts(map_node->name(), parallel_map->name()); nodes_to_delete.insert(map_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 0576d075c252a5..3401dcc6f23bae 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -147,7 +147,7 @@ NodeDef MakeNewBatchNode(const NodeDef& old_batch_node, MutableGraphView* graph) { NodeDef batch_node; batch_node.set_op(old_batch_node.op()); - graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->graph(), &batch_node); // Set the `input_dataset` input argument @@ -187,8 +187,7 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node, MutableGraphView* graph) { NodeDef map_node; map_node.set_op(old_map_node.op()); - graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(), - &map_node); + graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node); // Set the `input_dataset` input argument map_node.add_input(new_batch_node.name()); @@ -265,7 +264,7 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item, auto* new_map_node = graph.AddNode(MakeNewMapNode( *map_node, batch_node, *new_batch_node, *vectorized_func, &graph)); - graph.ReplaceInput(batch_node, *new_map_node); + graph.UpdateFanouts(batch_node.name(), new_map_node->name()); // Mark the `Map` and `Batch` nodes for removal. nodes_to_delete.insert(map_node->name()); diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index e47e91a375bc72..bd405c83294647 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -30,7 +30,7 @@ namespace tensorflow { namespace grappler { namespace { -bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) { +bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) { if (take_node.op() != "TakeDataset") return false; const auto& count_node = *graph.GetNode(take_node.input(1)); @@ -44,25 +44,26 @@ bool IsConstNodeWithValue(const NodeDef& node, int value) { return node.attr().at("value").tensor().int64_val(0) == value; } -bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) { +bool IsSkipNone(const NodeDef& skip_node, const MutableGraphView& graph) { if (skip_node.op() != "SkipDataset") return false; // We are looking only for skip(0) nodes. return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0); } -bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) { +bool IsRepeatOne(const NodeDef& repeat_node, const MutableGraphView& graph) { if (repeat_node.op() != "RepeatDataset") return false; // We are looking only for repeat(1) nodes. return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1); } -bool IsPrefetchZero(const NodeDef& prefetch_node, const GraphView& graph) { +bool IsPrefetchZero(const NodeDef& prefetch_node, + const MutableGraphView& graph) { if (prefetch_node.op() != "PrefetchDataset") return false; // We are looking only for prefetch(0) nodes. return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0); } -bool IsNoOp(const NodeDef& node, const GraphView& graph) { +bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) { return IsTakeAll(node, graph) || IsSkipNone(node, graph) || IsRepeatOne(node, graph) || IsPrefetchZero(node, graph); } @@ -78,7 +79,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item, if (!IsNoOp(node, graph)) continue; NodeDef* const parent = graph_utils::GetInputNode(node, graph); - graph.ReplaceInput(node, *parent); + graph.UpdateFanouts(node.name(), parent->name()); nodes_to_delete.insert(node.name()); } diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index 99c4afa6340094..d9af78d38cd590 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -86,7 +86,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster, NodeDef* shuffle_and_repeat_node = graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node)); - graph.ReplaceInput(repeat_node, *shuffle_and_repeat_node); + graph.UpdateFanouts(repeat_node.name(), shuffle_and_repeat_node->name()); // Mark the `Shuffle` and `Repeat` nodes for removal. nodes_to_delete.insert(shuffle_node.name()); diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc index f0696eb76d02cc..556e1d3ab57947 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc @@ -127,7 +127,7 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); + EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output)); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 01652172e3d077..5175f6af7a213d 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -60,13 +60,6 @@ tf_cc_test( ] + tf_protos_all(), ) -cc_library( - name = "cast_vectorizer", - srcs = ["cast_vectorizer.cc"], - deps = VECTORIZER_DEPS, - alwayslink = 1, -) - cc_library( name = "cwise_op_vectorizer", srcs = ["cwise_op_vectorizer.cc"], @@ -95,6 +88,18 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "transpose_vectorizer", + srcs = ["transpose_vectorizer.cc"], + deps = VECTORIZER_DEPS + [ + ":vectorizer", + ":wrapped_tensor", + "//tensorflow/cc:scope", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + cc_library( name = "unpack_vectorizer", srcs = ["unpack_vectorizer.cc"], @@ -107,13 +112,14 @@ cc_library( hdrs = ["vectorizer_registry.h"], visibility = ["//visibility:public"], deps = [ - ":cast_vectorizer", ":cwise_op_vectorizer", ":decode_csv_vectorizer", ":parse_single_example_vectorizer", ":reshape_vectorizer", + ":transpose_vectorizer", ":unpack_vectorizer", ":vectorizer", ":vectorizer_registry", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc deleted file mode 100644 index f4451575313903..00000000000000 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" - -namespace tensorflow { -namespace grappler { -namespace { - -class CastVectorizer : public Vectorizer { - public: - Status Vectorize(const Node& node, Graph* outer_scope, - std::vector&& inputs, - std::vector* outputs) override { - Status s; - if (node.num_inputs() != 1) { - return errors::Internal("Cast op should only have one input."); - } - - // Add new Cast node with the same op and attrs as the original node - auto new_cast_node = outer_scope->AddNode(node.def(), &s); - TF_RETURN_IF_ERROR(s); - - outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node, - 0); - - // Add output mappings - outputs->push_back({new_cast_node, 0, true}); - return Status::OK(); - } -}; - -REGISTER_VECTORIZER("Cast", CastVectorizer); - -} // namespace -} // namespace grappler -} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cwise_op_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cwise_op_vectorizer.cc index d26de1b36dd9f8..709882e45aec2b 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cwise_op_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cwise_op_vectorizer.cc @@ -119,136 +119,163 @@ Status ExpandDimsForBroadcast(std::vector* inputs, Graph* g) { return status; } -// Vectorizer for component-wise ops. Since these operations act component-wise, -// the vectorized op is the same as the original, with additional -// instrumentation to support correct broadcasting for binary ops. -class CwiseOpVectorizer : public Vectorizer { +// Vectorization helper for component-wise ops. Since these operations act +// component-wise, the vectorized op is the same as the original. +Status CwiseVectorizeHelper(const Node& node, Graph* outer_scope, + std::vector&& inputs, + std::vector* outputs) { + // Add new node with the same op type and attrs as the original node + Node* new_node; + auto node_builder = NodeBuilder(strings::StrCat("vectorized/", node.name()), + node.type_string()); + for (const auto& input : inputs) { + node_builder = node_builder.Input(input.node, input.output_index); + } + for (const auto& attr_slice : node.attrs()) { + node_builder = node_builder.Attr(attr_slice.first, attr_slice.second); + } + TF_RETURN_IF_ERROR(node_builder.Finalize(outer_scope, &new_node)); + + // Add output mappings + outputs->push_back({new_node, 0, true}); + return Status::OK(); +} + +class UnaryCwiseOpVectorizer : public Vectorizer { public: Status Vectorize(const Node& node, Graph* outer_scope, std::vector&& inputs, std::vector* outputs) override { - if (inputs.size() > 1) { - // Binary ops support broadcasting - TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope)); + if (inputs.size() != 1) { + return errors::Internal("Failed to vectorize ", node.type_string(), + ". The op should have 1 input, but has ", + inputs.size()); } - // Add new node with the same op type and attrs as the original node - Node* new_node; - auto node_builder = NodeBuilder(strings::StrCat("vectorized/", node.name()), - node.type_string()); - for (const auto& input : inputs) { - node_builder = node_builder.Input(input.node, input.output_index); - } - for (const auto& attr_slice : node.attrs()) { - node_builder = node_builder.Attr(attr_slice.first, attr_slice.second); + return CwiseVectorizeHelper(node, outer_scope, std::move(inputs), outputs); + } +}; + +class BinaryCwiseOpVectorizer : public Vectorizer { + public: + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector&& inputs, + std::vector* outputs) override { + if (inputs.size() != 2) { + return errors::Internal("Failed to vectorize ", node.type_string(), + ". The op should have 2 input, but has ", + inputs.size()); } - TF_RETURN_IF_ERROR(node_builder.Finalize(outer_scope, &new_node)); + // Binary ops support broadcasting + TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope)); - // Add output mappings - outputs->push_back({new_node, 0, true}); - return Status::OK(); + return CwiseVectorizeHelper(node, outer_scope, std::move(inputs), outputs); } }; // Bitwise unary -REGISTER_VECTORIZER("Invert", CwiseOpVectorizer); +REGISTER_VECTORIZER("Invert", UnaryCwiseOpVectorizer); // Logical unary -REGISTER_VECTORIZER("LogicalNot", CwiseOpVectorizer); +REGISTER_VECTORIZER("LogicalNot", UnaryCwiseOpVectorizer); // Complex unary -REGISTER_VECTORIZER("Angle", CwiseOpVectorizer); -REGISTER_VECTORIZER("ComplexAbs", CwiseOpVectorizer); -REGISTER_VECTORIZER("Conj", CwiseOpVectorizer); -REGISTER_VECTORIZER("Imag", CwiseOpVectorizer); -REGISTER_VECTORIZER("Real", CwiseOpVectorizer); +REGISTER_VECTORIZER("Angle", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("ComplexAbs", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Conj", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Imag", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Real", UnaryCwiseOpVectorizer); // Real unary -REGISTER_VECTORIZER("Abs", CwiseOpVectorizer); -REGISTER_VECTORIZER("Acos", CwiseOpVectorizer); -REGISTER_VECTORIZER("Acosh", CwiseOpVectorizer); -REGISTER_VECTORIZER("Asin", CwiseOpVectorizer); -REGISTER_VECTORIZER("Asinh", CwiseOpVectorizer); -REGISTER_VECTORIZER("Atan", CwiseOpVectorizer); -REGISTER_VECTORIZER("Atanh", CwiseOpVectorizer); -REGISTER_VECTORIZER("BesselI0e", CwiseOpVectorizer); -REGISTER_VECTORIZER("BesselI1e", CwiseOpVectorizer); -REGISTER_VECTORIZER("Ceil", CwiseOpVectorizer); -REGISTER_VECTORIZER("Cos", CwiseOpVectorizer); -REGISTER_VECTORIZER("Cosh", CwiseOpVectorizer); -REGISTER_VECTORIZER("Digamma", CwiseOpVectorizer); -REGISTER_VECTORIZER("Elu", CwiseOpVectorizer); -REGISTER_VECTORIZER("Erf", CwiseOpVectorizer); -REGISTER_VECTORIZER("Erfc", CwiseOpVectorizer); -REGISTER_VECTORIZER("Exp", CwiseOpVectorizer); -REGISTER_VECTORIZER("Expm1", CwiseOpVectorizer); -REGISTER_VECTORIZER("Floor", CwiseOpVectorizer); -REGISTER_VECTORIZER("Inv", CwiseOpVectorizer); -REGISTER_VECTORIZER("IsFinite", CwiseOpVectorizer); -REGISTER_VECTORIZER("IsInf", CwiseOpVectorizer); -REGISTER_VECTORIZER("Lgamma", CwiseOpVectorizer); -REGISTER_VECTORIZER("Log", CwiseOpVectorizer); -REGISTER_VECTORIZER("Log1p", CwiseOpVectorizer); -REGISTER_VECTORIZER("Neg", CwiseOpVectorizer); -REGISTER_VECTORIZER("Reciprocal", CwiseOpVectorizer); -REGISTER_VECTORIZER("Relu", CwiseOpVectorizer); -REGISTER_VECTORIZER("Relu6", CwiseOpVectorizer); -REGISTER_VECTORIZER("Rint", CwiseOpVectorizer); -REGISTER_VECTORIZER("Round", CwiseOpVectorizer); -REGISTER_VECTORIZER("Rsqrt", CwiseOpVectorizer); -REGISTER_VECTORIZER("Selu", CwiseOpVectorizer); -REGISTER_VECTORIZER("Sigmoid", CwiseOpVectorizer); -REGISTER_VECTORIZER("Sign", CwiseOpVectorizer); -REGISTER_VECTORIZER("Sin", CwiseOpVectorizer); -REGISTER_VECTORIZER("Sinh", CwiseOpVectorizer); -REGISTER_VECTORIZER("Softplus", CwiseOpVectorizer); -REGISTER_VECTORIZER("Softsign", CwiseOpVectorizer); -REGISTER_VECTORIZER("Sqrt", CwiseOpVectorizer); -REGISTER_VECTORIZER("Square", CwiseOpVectorizer); -REGISTER_VECTORIZER("Tanh", CwiseOpVectorizer); -REGISTER_VECTORIZER("Tan", CwiseOpVectorizer); +REGISTER_VECTORIZER("Abs", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Acos", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Acosh", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Asin", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Asinh", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Atan", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Atanh", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("BesselI0e", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("BesselI1e", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Ceil", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Cos", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Cosh", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Digamma", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Elu", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Erf", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Erfc", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Exp", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Expm1", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Floor", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Inv", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("IsFinite", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("IsInf", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Lgamma", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Log", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Log1p", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Neg", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Reciprocal", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Relu", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Relu6", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Rint", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Round", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Rsqrt", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Selu", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Sigmoid", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Sign", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Sin", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Sinh", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Softplus", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Softsign", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Sqrt", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Square", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Tanh", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Tan", UnaryCwiseOpVectorizer); + +// Miscellaneous unary +REGISTER_VECTORIZER("Cast", UnaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Identity", UnaryCwiseOpVectorizer); // Bitwise binary -REGISTER_VECTORIZER("BitwiseAnd", CwiseOpVectorizer); -REGISTER_VECTORIZER("BitwiseOr", CwiseOpVectorizer); -REGISTER_VECTORIZER("BitwiseXor", CwiseOpVectorizer); -REGISTER_VECTORIZER("LeftShift", CwiseOpVectorizer); -REGISTER_VECTORIZER("RightShift", CwiseOpVectorizer); +REGISTER_VECTORIZER("BitwiseAnd", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("BitwiseOr", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("BitwiseXor", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("LeftShift", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("RightShift", BinaryCwiseOpVectorizer); // Logical binary -REGISTER_VECTORIZER("LogicalAnd", CwiseOpVectorizer); -REGISTER_VECTORIZER("LogicalOr", CwiseOpVectorizer); +REGISTER_VECTORIZER("LogicalAnd", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("LogicalOr", BinaryCwiseOpVectorizer); // Real binary -REGISTER_VECTORIZER("Add", CwiseOpVectorizer); -REGISTER_VECTORIZER("AddV2", CwiseOpVectorizer); -REGISTER_VECTORIZER("Atan2", CwiseOpVectorizer); -REGISTER_VECTORIZER("Complex", CwiseOpVectorizer); -REGISTER_VECTORIZER("Div", CwiseOpVectorizer); -REGISTER_VECTORIZER("DivNoNan", CwiseOpVectorizer); -REGISTER_VECTORIZER("Equal", CwiseOpVectorizer); -REGISTER_VECTORIZER("FloorDiv", CwiseOpVectorizer); -REGISTER_VECTORIZER("FloorMod", CwiseOpVectorizer); -REGISTER_VECTORIZER("Greater", CwiseOpVectorizer); -REGISTER_VECTORIZER("GreaterEqual", CwiseOpVectorizer); -REGISTER_VECTORIZER("Igamma", CwiseOpVectorizer); -REGISTER_VECTORIZER("Igammac", CwiseOpVectorizer); -REGISTER_VECTORIZER("IgammaGradA", CwiseOpVectorizer); -REGISTER_VECTORIZER("Less", CwiseOpVectorizer); -REGISTER_VECTORIZER("LessEqual", CwiseOpVectorizer); -REGISTER_VECTORIZER("Maximum", CwiseOpVectorizer); -REGISTER_VECTORIZER("Minimum", CwiseOpVectorizer); -REGISTER_VECTORIZER("Mod", CwiseOpVectorizer); -REGISTER_VECTORIZER("Mul", CwiseOpVectorizer); -REGISTER_VECTORIZER("NotEqual", CwiseOpVectorizer); -REGISTER_VECTORIZER("Polygamma", CwiseOpVectorizer); -REGISTER_VECTORIZER("Pow", CwiseOpVectorizer); -REGISTER_VECTORIZER("RealDiv", CwiseOpVectorizer); -REGISTER_VECTORIZER("SquaredDifference", CwiseOpVectorizer); -REGISTER_VECTORIZER("Sub", CwiseOpVectorizer); -REGISTER_VECTORIZER("TruncateDiv", CwiseOpVectorizer); -REGISTER_VECTORIZER("TruncateMod", CwiseOpVectorizer); -REGISTER_VECTORIZER("Zeta", CwiseOpVectorizer); +REGISTER_VECTORIZER("Add", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("AddV2", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Atan2", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Complex", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Div", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("DivNoNan", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Equal", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("FloorDiv", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("FloorMod", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Greater", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("GreaterEqual", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Igamma", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Igammac", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("IgammaGradA", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Less", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("LessEqual", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Maximum", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Minimum", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Mod", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Mul", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("NotEqual", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Polygamma", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Pow", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("RealDiv", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("SquaredDifference", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Sub", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("TruncateDiv", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("TruncateMod", BinaryCwiseOpVectorizer); +REGISTER_VECTORIZER("Zeta", BinaryCwiseOpVectorizer); } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/transpose_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/transpose_vectorizer.cc new file mode 100644 index 00000000000000..4c286d9c4a925b --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/transpose_vectorizer.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +namespace { + +constexpr char kTransposePrefix[] = "vectorized/transpose"; + +class TransposeVectorizer : public Vectorizer { + public: + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector&& inputs, + std::vector* outputs) override { + if (!inputs[0].stacked || inputs[1].stacked) { + return errors::InvalidArgument( + "Expecting input 0 (`x`) to be stacked and input 1 (`perm`) to " + "be unstacked."); + } + + Status status; + Scope parent = NewInternalScope(outer_scope, &status, /*refiner=*/nullptr); + Scope scope = parent.NewSubScope(kTransposePrefix); + + Output tensor = {inputs[0].node, inputs[0].output_index}; + Output original_perm = {inputs[1].node, inputs[1].output_index}; + + // The vectorized permutation is the original permutation with an additional + // leading 0 and all other values incremented by 1. + // perm = tf.concat([[0], original_perm + 1], axis=0) + Output perm = + ops::Concat(scope, + std::initializer_list( + {ops::Const(scope, {0}), + ops::Add(scope, original_perm, ops::Const(scope, 1))}), + ops::Const(scope, 0)); + + Output vectorized_transpose = ops::Transpose(scope, tensor, perm); + + TF_RETURN_IF_ERROR(status); + + // Add output mappings. + outputs->push_back({vectorized_transpose.node(), 0, true}); + return Status::OK(); + } +}; + +REGISTER_VECTORIZER("Transpose", TransposeVectorizer); + +} // namespace + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 8b93b1f2b8339f..60c557d557e311 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" @@ -129,6 +130,7 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope, // This class transforms the input FunctionDefs into their corresponding // Graph objects and works on the graphs directly, then converts them back // to FunctionDefs when GetResult is called. +// TODO(rachelim): Move this to its own header. class Vectorization { public: explicit Vectorization(FunctionDefLibrary* lib) @@ -181,18 +183,25 @@ class Vectorization { Status StackTensor(WrappedTensor* unstacked, TensorDesc* result); // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by - // doing a depth-first search from the ret nodes. Lifts nodes that are - // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly - // and add mappings to `conversion_map_`. - Status AddUnstackedNodeMappings(); - - // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor - // is unstacked. - bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status); - - // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input - // nodes to `conversion_map_`. - Status AddArgNodeMappings(); + // doing a depth-first search from the ret nodes. Lifts tensors that are + // unstacked (i.e. don't derive from arg tensors) into `outer_scope_` directly + // and adds mappings to `conversion_map_`. + // Note that this function may have false negatives, i.e. not + // add mappings for some tensors that are unstacked. This may happen in the + // following cases: 1) a vectorized op produces unstacked outputs from stacked + // inputs (e.g. the vectorized "Shape" op), 2) the tensors are in a cycle, or + // 3) the unstacked op could not be lifted into `outer_scope`. + Status AddUnstackedTensorMappings(); + + // Recursive helper for `AddUnstackedTensorMappings`. If an op node is + // unstacked, lifts its output tensors into `outer_scope`, adding the mappings + // to `conversion_map`. Returns true if the unstacked mappings were added. + bool AddUnstackedTensorMappingsHelper( + TensorDesc&& tensor, absl::flat_hash_set* visited); + + // Add mappings from `map_defun_fn_` arg tensors to `map_defun_node_` input + // tensors to `conversion_map_`. + Status AddArgTensorMappings(); // Maps a tensor to the corresponding WrappedTensor. For example, // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true) @@ -395,9 +404,8 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope, } map_defun_node_ = outer_scope_->FindNodeId(node_id); - TF_RETURN_IF_ERROR(AddArgNodeMappings()); - - TF_RETURN_IF_ERROR(AddUnstackedNodeMappings()); + TF_RETURN_IF_ERROR(AddArgTensorMappings()); + TF_RETURN_IF_ERROR(AddUnstackedTensorMappings()); loop_len_node_ = nullptr; return Status::OK(); @@ -488,7 +496,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, return Status::OK(); } -Status Vectorization::AddArgNodeMappings() { +Status Vectorization::AddArgTensorMappings() { // Note that inputs to map_defun_fn_ are either regular arguments (for which // the operations are mapped across their 0th dimension) or captured inputs // (for which the operations apply to the argument wholesale). @@ -523,8 +531,8 @@ Status Vectorization::AddArgNodeMappings() { return Status::OK(); } -bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, - Status* status) { +bool Vectorization::AddUnstackedTensorMappingsHelper( + TensorDesc&& tensor, absl::flat_hash_set* visited) { if (auto found = gtl::FindOrNull(conversion_map_, tensor)) { return !found->stacked; } @@ -536,14 +544,22 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, } bool is_unstacked = true; - for (auto edge : tensor.first->in_edges()) { + for (const auto& edge : tensor.first->in_edges()) { // Ignore Source nodes. Note that these are also ignored in the // GraphToFunctionDef conversion. if (edge->src()->IsSource()) continue; + if (visited->find(edge) != visited->end()) { + // If we've visited this edge already, we're in a cycle. In this case, we + // are conservative and don't mark the node as unstacked. + is_unstacked = false; + continue; + } + visited->insert(edge); + // A node is unstacked if all of its inputs are unstacked - is_unstacked &= AddUnstackedNodeMappingsHelper( - {edge->src(), edge->src_output()}, status); + is_unstacked &= AddUnstackedTensorMappingsHelper( + {edge->src(), edge->src_output()}, visited); } if (!is_unstacked) { @@ -553,11 +569,12 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, // If the node is unstacked, we copy it into outer_scope_ and // add it to the map. Note that we don't clean up the nodes that are copied // in map_defun_fn_, and rely on them being pruned out later. - Node* node = outer_scope_->AddNode(tensor.first->def(), status); - if (!status->ok()) return true; + Status status; + Node* node = outer_scope_->AddNode(tensor.first->def(), &status); + if (!status.ok()) return false; // Add input edges to nodes that should already have been lifted. - for (auto edge : tensor.first->in_edges()) { + for (const auto& edge : tensor.first->in_edges()) { // Ignore Source nodes. Note that these are also ignored in the // GraphToFunctionDef conversion. if (edge->src()->IsSource()) continue; @@ -567,9 +584,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, outer_scope_->AddEdge(found->node, found->output_index, node, edge->dst_input()); } else { - status->Update(errors::Internal( - "Could not find input conversion even though we did depth first " - "conversion.")); + return false; } } @@ -583,14 +598,13 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, return true; } -Status Vectorization::AddUnstackedNodeMappings() { - SetVector unstacked_nodes; - Status s; +Status Vectorization::AddUnstackedTensorMappings() { + absl::flat_hash_set visited; for (const auto& ret_node : map_defun_fn_->ret_nodes) { const Edge* in_edge = nullptr; TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge)); - AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s); - TF_RETURN_IF_ERROR(s); + AddUnstackedTensorMappingsHelper({in_edge->src(), in_edge->src_output()}, + &visited); } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index aa3696b5be4cca..f5aa8c888e0dae 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -15,14 +15,30 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" +#include +#include +#include +#include + +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { @@ -30,39 +46,73 @@ namespace grappler { namespace vectorization_utils { namespace { -NodeDef* AddCastNode(const string& name, const std::vector& inputs, - DataType src, DataType dst, bool truncate, - FunctionDef* fn) { - NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn); - graph_transforms::SetNodeAttr("SrcT", src, node); - graph_transforms::SetNodeAttr("DstT", dst, node); - graph_transforms::SetNodeAttr("Truncate", truncate, node); - return node; +// Wraps a function in another function with a MapDefun node +Status WrapFunctionWithMapDefun(const FunctionDef& inner, FunctionDef* result) { + Graph graph(OpRegistry::Global()); + std::vector inputs; + inputs.reserve(inner.signature().input_arg_size()); + for (int i = 0; i < inner.signature().input_arg_size(); ++i) { + Node* arg; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("arg", i), /*op_name=*/"_Arg") + .Attr("T", inner.signature().input_arg(i).type()) + .Attr("index", i) + .Finalize(&graph, &arg)); + inputs.push_back(arg); + } + + DataTypeVector output_types; + output_types.reserve(inner.signature().output_arg_size()); + for (const auto& output_arg : inner.signature().output_arg()) { + output_types.push_back(output_arg.type()); + } + + Node* map_defun_node; + NameAttrList func_attr; + func_attr.set_name(inner.signature().name()); + TF_RETURN_IF_ERROR( + NodeBuilder("map_defun", "MapDefun") + .Input(inputs) // arguments + .Input(std::vector()) // captured_inputs + .Attr("f", func_attr) + .Attr("output_types", output_types) + .Attr("output_shapes", std::vector( + inner.signature().output_arg_size())) + .Finalize(&graph, &map_defun_node)); + + for (size_t i = 0; i < map_defun_node->num_outputs(); ++i) { + Node* ret; + TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("ret", i), "_Retval") + .Input(map_defun_node, i) + .Attr("index", static_cast(i)) + .Finalize(&graph, &ret)); + } + + return GraphToFunctionDef(graph, "outer_function", result); } -NodeDef* AddUnstackNode(const string& name, const std::vector& inputs, - DataType t, int axis, int num, FunctionDef* fn) { - NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn); - graph_transforms::SetNodeAttr("T", t, node); - graph_transforms::SetNodeAttr("axis", axis, node); - graph_transforms::SetNodeAttr("num", num, node); - return node; +// Wraps the function `fn` in another function with a MapDefun node, then +// vectorizes the wrapper function with VectorizeMapDefun. +Status WrapAndVectorize(const FunctionDef& fn, FunctionDefLibrary* lib, + FunctionDef** result) { + FunctionDef outer; + TF_RETURN_IF_ERROR(WrapFunctionWithMapDefun(fn, &outer)); + const NodeDef& map_defun_node = outer.node_def(0); + + *lib->add_function() = outer; + *lib->add_function() = fn; + + TF_RETURN_IF_ERROR(VectorizeMapDefun(outer, map_defun_node, lib, result)); + + return Status::OK(); } -NodeDef* AddMapDefunNode(const string& name, const std::vector& inputs, - const std::vector& t_arguments, - const std::vector& output_types, - const std::vector& output_shapes, - const string& function_name, FunctionDef* fn) { - NameAttrList func; - func.set_name(function_name); - NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); - graph_transforms::SetNodeAttr("Targuments", t_arguments, node); - graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node); - graph_transforms::SetNodeAttr("output_types", output_types, node); - graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); - graph_transforms::SetNodeAttr("f", func, node); - return node; +FunctionDefHelper::Node Cast(string&& name, std::vector&& inputs, + DataType src, DataType dst) { + return {{name}, + "Cast", + inputs, + {{"SrcT", src}, {"DstT", dst}, {"Truncate", false}}}; } string GetRetval(const FunctionDef& function_def, int index) { @@ -70,31 +120,6 @@ string GetRetval(const FunctionDef& function_def, int index) { function_def.signature().output_arg(index).name()); } -// TODO(rachelim): Use FunctionDefHelper::Create instead -FunctionDef CreateFunction( - StringPiece name, const std::vector>& inputs, - const std::vector>& outputs, - const std::map& rets) { - FunctionDef func; - auto* signature = func.mutable_signature(); - signature->set_name(string(name)); - for (const auto& x : inputs) { - auto* arg_def = signature->add_input_arg(); - arg_def->set_name(x.first); - arg_def->set_type(x.second); - } - for (const auto& x : outputs) { - auto* arg_def = signature->add_output_arg(); - arg_def->set_name(x.first); - arg_def->set_type(x.second); - } - for (const auto& x : rets) { - (*func.mutable_ret())[x.first] = x.second; - } - - return func; -} - ///==================================// // Tests for vectorization framework // ///==================================// @@ -131,31 +156,23 @@ FunctionDef CreateFunction( // +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeWithNoOps) { - FunctionDef inner = - CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, - {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"ret0", "arg0"}, {"ret1", "arg1"}}); - FunctionDef outer = CreateFunction( - "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}}, - {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); - - NodeDef* map_defun = AddMapDefunNode( - "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}, - {{}, {}}, inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); - + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32", "arg1: int32"}, + /*out_def=*/{"ret0: int32", "ret1: int32"}, + /*attr_def=*/{}, + /*node_def=*/{}, + /*ret_def=*/{{"ret0", "arg0"}, {"ret1", "arg1"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); - LOG(ERROR) << s; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); + EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); - EXPECT_EQ(GetRetval(*vectorized, 0), "ret0"); - EXPECT_EQ(GetRetval(*vectorized, 1), "ret1"); + EXPECT_EQ(GetRetval(*vectorized, 0), + vectorized->signature().input_arg(0).name()); + EXPECT_EQ(GetRetval(*vectorized, 1), + vectorized->signature().input_arg(1).name()); } // Before: @@ -214,41 +231,32 @@ TEST(VectorizeMapDefunTest, VectorizeWithNoOps) { // +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeWithUnvectorizableOp) { - FunctionDef inner = - CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, - {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}}); + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32", "arg1: int32"}, + /*out_def=*/{"ret0: int32", "ret1: int32"}, + /*attr_def=*/{}, + /*node_def=*/ + {{{"MatMul"}, "MatMul", {"arg0", "arg0"}, {{"T", DT_INT32}}}, + Cast("Cast", {"arg1"}, DT_INT32, DT_INT32)}, // + /*ret_def=*/{{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}}); // TODO(rachelim): If we ever write a converter for MatMul, we have to // change this test. - NodeDef* x_op1 = - function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner); - CHECK_NOTNULL(x_op1); - graph_transforms::SetNodeAttr("T", DT_INT32, x_op1); - - NodeDef* cast_node = - AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner); - CHECK_NOTNULL(cast_node); - - FunctionDef outer = CreateFunction( - "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, - {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}}, - {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); - - NodeDef* map_defun = AddMapDefunNode( - "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}, - {{}, {}}, inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); + ASSERT_TRUE( + function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); auto map_defun_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized)); + // The Cast node should be converted just fine. - EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0"); + ASSERT_TRUE(function_utils::ContainsFunctionNodeWithOp("Cast", *vectorized)); + auto cast = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(cast.name(), ":y:0")); // The inner function should only have one retval. FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); @@ -301,34 +309,23 @@ TEST(VectorizeMapDefunTest, VectorizeWithUnvectorizableOp) { // TEST(VectorizeMapDefunTest, VectorizeWithOutputUsedTwice) { // Tests that behavior is correct when an output is used more than once. - FunctionDef inner = - CreateFunction("inner_function", {{"arg0", DT_INT32}}, - {{"ret0", DT_INT64}, {"ret1", DT_INT64}}, - {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}}); - NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); - CHECK_NOTNULL(cast_op); - - FunctionDef outer = CreateFunction( - "outer_function", {{"x", DT_INT32}}, - {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}}, - {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64}, - {{}, {}}, inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int64", "ret1: int64"}, + /*attr_def=*/{}, + /*node_def=*/{Cast("Cast", {"arg0"}, DT_INT32, DT_INT64)}, + /*ret_def=*/{{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); + EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& cast_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); - EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(cast_node.input(0), vectorized->signature().input_arg(0).name()); EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); EXPECT_EQ(GetRetval(*vectorized, 1), @@ -386,42 +383,30 @@ TEST(VectorizeMapDefunTest, VectorizeWithOutputUsedTwice) { // +------+ +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeWithChainedConvertibleOps) { - FunctionDef inner = CreateFunction( - "inner_function", {{"arg0", DT_INT32}}, - {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}}, + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int32", "ret1: int32", "ret2: int32"}, + /*attr_def=*/{}, + /*node_def=*/ + {Cast("Cast", {"arg0"}, DT_INT32, DT_INT32), + {{"MyUnstack"}, + "Unpack", + {"Cast:y:0"}, + {{"T", DT_INT32}, {"axis", 0}, {"num", 3}}}}, + /*ret_def=*/ {{"ret0", "MyUnstack:output:0"}, {"ret1", "MyUnstack:output:1"}, {"ret2", "MyUnstack:output:2"}}); - NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner); - CHECK_NOTNULL(cast_op); - NodeDef* unstack_op = - AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); - CHECK_NOTNULL(unstack_op); - - FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, - {{"mapdefun", DT_INT32}, - {"mapdefun_0", DT_INT32}, - {"mapdefun_1", DT_INT32}}, - {{"mapdefun", "MapDefun:output:0"}, - {"mapdefun_0", "MapDefun:output:1"}, - {"mapdefun_1", "MapDefun:output:2"}}); - - NodeDef* map_defun = AddMapDefunNode( - "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32}, - {{1}, {1}, {1}}, inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& cast_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); - EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(cast_node.input(0), vectorized->signature().input_arg(0).name()); const NodeDef& unpack_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); @@ -470,33 +455,22 @@ TEST(VectorizeMapDefunTest, VectorizeWithChainedConvertibleOps) { // No change because we don't deal with control inputs for now. // TEST(VectorizeMapDefunTest, VectorizeWithControlInputs) { - FunctionDef inner = - CreateFunction("inner_function", {{"arg0", DT_INT32}}, - {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); - NodeDef* print_op = function_utils::AddNode( - "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner); - graph_transforms::SetNodeAttr("T", DT_INT32, print_op); - graph_transforms::SetNodeAttr("U", gtl::ArraySlice({DT_INT32}), - print_op); - CHECK_NOTNULL(print_op); - NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, - false, &inner); - CHECK_NOTNULL(cast_op); - - FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, - {{"mapdefun", DT_INT64}}, - {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}}, - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int64"}, + /*attr_def=*/{}, + /*node_def=*/ + {{{"Print"}, + "Print", + {"arg0", "arg0"}, + {{"T", DT_INT32}, {"U", gtl::ArraySlice({DT_INT32})}}}, + Cast("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64)}, + /*ret_def=*/{{"ret0", "Cast:y:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); // They should be unchanged // We check this somewhat manually as the names of nodes may have changed EXPECT_EQ(vectorized->node_def_size(), 1); @@ -571,24 +545,18 @@ TEST(VectorizeMapDefunTest, VectorizeWithControlInputs) { // TEST(VectorizeMapDefunTest, VectorizeWithUnstackedOutput) { FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, - {/* nodes */ FunctionDefHelper::Const("Const", 2)}, - {{"ret0", "Cast:y:0"}}); - AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner); - - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, - {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, - inner.signature().name(), &outer); + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int64"}, + /*attr_def=*/{}, + /*node_def=*/ + {FunctionDefHelper::Const("Const", 2), + Cast("Cast", {"Const:output:0"}, DT_INT32, DT_INT64)}, + /*ret_def=*/{{"ret0", "Cast:y:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); auto const_node = vectorized->node_def( @@ -653,27 +621,19 @@ TEST(VectorizeMapDefunTest, VectorizeWithUnstackedOutput) { // TEST(VectorizeMapDefunTest, VectorizeWithUnstackedControl) { FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, - {/* nodes */ FunctionDefHelper::Const("Const", 2), - FunctionDefHelper::Const("ConstDep", 3)}, - {{"ret0", "Cast:y:0"}}); - AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64, - false, &inner); - - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, - {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, - inner.signature().name(), &outer); + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int64"}, + /*attr_def=*/{}, + /*node_def=*/ + {FunctionDefHelper::Const("Const", 2), + FunctionDefHelper::Const("ConstDep", 3), + Cast("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64)}, + /*ret_def=*/{{"ret0", "Cast:y:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; - FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); auto find_const = [vectorized](int val) -> const NodeDef* { for (const auto& n : vectorized->node_def()) { @@ -745,39 +705,29 @@ TEST(VectorizeMapDefunTest, VectorizeWithUnstackedControl) { // +------+ +------+ +------+ // TEST(VectorizerTest, VectorizeUnstack) { - FunctionDef inner = CreateFunction( - "inner_function", {{"arg0", DT_INT32}}, - {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}}, + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int32", "ret1: int32", "ret2: int32"}, + /*attr_def=*/{}, + /*node_def=*/ + {{{"MyUnstack"}, + "Unpack", + {"arg0"}, + {{"T", DT_INT32}, {"axis", 0}, {"num", 3}}}}, + /*ret_def=*/ {{"ret0", "MyUnstack:output:0"}, {"ret1", "MyUnstack:output:1"}, {"ret2", "MyUnstack:output:2"}}); - NodeDef* unstack_op = - AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner); - CHECK_NOTNULL(unstack_op); - - FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, - {{"mapdefun", DT_INT32}, - {"mapdefun_0", DT_INT32}, - {"mapdefun_1", DT_INT32}}, - {{"mapdefun", "MapDefun:output:0"}, - {"mapdefun_0", "MapDefun:output:1"}, - {"mapdefun_1", "MapDefun:output:2"}}); - - NodeDef* map_defun = AddMapDefunNode( - "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32}, - {{1}, {1}, {1}}, inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& unpack_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); - EXPECT_EQ(unpack_node.input(0), "x"); + EXPECT_EQ(unpack_node.input(0), vectorized->signature().input_arg(0).name()); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); @@ -830,32 +780,22 @@ TEST(VectorizerTest, VectorizeUnstack) { // +------+ // TEST(VectorizerTest, VectorizeCast) { - FunctionDef inner = - CreateFunction("inner_function", {{"arg0", DT_INT32}}, - {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); - NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); - CHECK_NOTNULL(cast_op); - - FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, - {{"mapdefun", DT_INT64}}, - {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}}, - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int64"}, + /*attr_def=*/{}, + /*node_def=*/{Cast("Cast", {"arg0"}, DT_INT32, DT_INT64)}, + /*ret_def=*/{{"ret0", "Cast:y:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& cast_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); - EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(cast_node.input(0), vectorized->signature().input_arg(0).name()); EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); EXPECT_EQ(vectorized->node_def_size(), 1); @@ -921,73 +861,22 @@ TEST(VectorizerTest, VectorizeAdd) { // tensorflow/python/data/experimental/kernel_tests/optimization/ // map_vectorization_test.py FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, - {/* nodes */ FunctionDefHelper::Const("Const", 2), + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int32"}, + /*attr_def=*/{}, + /*node_def=*/ + {FunctionDefHelper::Const("Const", 2), {{"Add"}, "Add", {"arg0", "Const:output:0"}, {{"T", DT_INT32}}}}, - {{"ret0", "Add:z:0"}}); - - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, - {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); + /*ret_def=*/{{"ret0", "Add:z:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); } -// Wraps inner function with a MapDefun node. -Status WrapFunctionWithMapDefun(const FunctionDef& inner, FunctionDef* result) { - Graph graph(OpRegistry::Global()); - std::vector inputs; - inputs.reserve(inner.signature().input_arg_size()); - for (int i = 0; i < inner.signature().input_arg_size(); ++i) { - Node* arg; - TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("arg", i), "_Arg") - .Attr("T", inner.signature().input_arg(i).type()) - .Attr("index", i) - .Finalize(&graph, &arg)); - inputs.push_back(arg); - } - - DataTypeVector output_types; - output_types.reserve(inner.signature().output_arg_size()); - for (const auto& output_arg : inner.signature().output_arg()) { - output_types.push_back(output_arg.type()); - } - - Node* map_defun_node; - NameAttrList func_attr; - func_attr.set_name(inner.signature().name()); - TF_RETURN_IF_ERROR( - NodeBuilder("map_defun", "MapDefun") - .Input(inputs) - .Input(std::vector({})) // captured - .Attr("f", func_attr) - .Attr("output_types", output_types) - .Attr("output_shapes", std::vector( - inner.signature().output_arg_size())) - .Finalize(&graph, &map_defun_node)); - - for (size_t i = 0; i < map_defun_node->num_outputs(); ++i) { - Node* ret; - TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("ret", i), "_Retval") - .Input(map_defun_node, i) - .Attr("index", static_cast(i)) - .Finalize(&graph, &ret)); - } - - return GraphToFunctionDef(graph, "outer_function", result); -} - // Tests that a function which applies a cwise op can be vectorized completely. Status CwiseTestHelper(DataType input_type, const string& op_type, size_t arity) { @@ -1022,16 +911,9 @@ Status CwiseTestHelper(DataType input_type, const string& op_type, TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, "inner_function", &inner)); - FunctionDef outer; - TF_RETURN_IF_ERROR(WrapFunctionWithMapDefun(inner, &outer)); - - const NodeDef* map_defun = &outer.node_def(0); - FunctionDefLibrary lib; FunctionDef* vectorized; - *lib.add_function() = outer; - *lib.add_function() = inner; - TF_RETURN_IF_ERROR(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_RETURN_IF_ERROR(WrapAndVectorize(inner, &lib, &vectorized)); return function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized) ? errors::Internal( @@ -1179,29 +1061,21 @@ INSTANTIATE_TEST_CASE_P( // TEST(VectorizerTest, VectorizeReshape) { FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, - {/* nodes */ FunctionDefHelper::Const("Const", - gtl::ArraySlice({3, 3, 3})), + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int32"}, + /*attr_def=*/{}, + /*node_def=*/ + {FunctionDefHelper::Const("Const", gtl::ArraySlice({3, 3, 3})), {{"Reshape"}, "Reshape", {"arg0", "Const:output:0"}, {{"T", DT_INT32}, {"Tshape", DT_INT32}}}}, - {{"ret0", "Reshape:output:0"}}); - - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, - {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); + /*ret_def=*/{{"ret0", "Reshape:output:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); EXPECT_TRUE( @@ -1275,31 +1149,23 @@ TEST(VectorizerTest, VectorizeReshape) { // TEST(VectorizerTest, VectorizeDecodeCSV) { FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: string"}, {"ret0: int32", "ret1: string"}, - {/* attrs */}, + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: string"}, + /*out_def=*/{"ret0: int32", "ret1: string"}, + /*attr_def=*/{}, + /*node_def=*/ {FunctionDefHelper::Const("Default0", gtl::ArraySlice({2})), FunctionDefHelper::Const("Default1", gtl::ArraySlice({})), {{"DecodeCSV"}, "DecodeCSV", {"arg0", "Default0:output:0", "Default1:output:0"}, {{"OUT_TYPE", DataTypeVector({DT_INT32, DT_STRING})}}}}, + /*ret_def=*/ {{"ret0", "DecodeCSV:output:0"}, {"ret1", "DecodeCSV:output:1"}}); - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: string"}, - {"mapdefun: int32", "mapdefun_0: string"}, {/* attrs */}, {/* nodes */}, - {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); - - NodeDef* map_defun = AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_STRING}, - {DT_INT32, DT_STRING}, {{}, {}}, - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); - FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); } @@ -1308,31 +1174,21 @@ TEST(VectorizerTest, VectorizeDecodeCSVWithStackedDefaults) { // When the `record_defaults` input to DecodeCSV are stacked, // the node should not be vectorized. FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: string", "arg1: int32", "arg2: string"}, - {"ret0: int32", "ret1: string"}, {/* attrs */}, + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: string", "arg1: int32", "arg2: string"}, + /*out_def=*/{"ret0: int32", "ret1: string"}, + /*attr_def=*/{}, + /*node_def=*/ {{{"DecodeCSV"}, "DecodeCSV", {"arg0", "arg1", "arg2"}, // Inputs come from args, which are "stacked" {{"OUT_TYPE", DataTypeVector({DT_INT32, DT_STRING})}}}}, + /*ret_def=*/ {{"ret0", "DecodeCSV:output:0"}, {"ret1", "DecodeCSV:output:1"}}); - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", - {"outer_arg0: string", "outer_arg1: int32", "outer_arg2: string"}, - {"mapdefun: int32", "mapdefun_0: string"}, {/* attrs */}, {/* nodes */}, - {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"outer_arg0", "outer_arg1", "outer_arg2"}, - {DT_STRING, DT_INT32, DT_STRING}, {DT_INT32, DT_STRING}, - {{}, {}}, inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); - FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); } @@ -1403,10 +1259,13 @@ TEST(VectorizerTest, VectorizeDecodeCSVWithStackedDefaults) { // TEST(VectorizerTest, VectorizeParseSingleExample) { FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: string"}, + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: string"}, + /*out_def=*/ {"si0: int64", "si1: int64", "sv0: int64", "sv1: string", "ss0: int64", "ss1: int64", "dv0: int64", "dv1: string"}, - {/* attrs */}, + /*attr_def=*/{}, + /*node_def=*/ {FunctionDefHelper::Const("DenseIntDefault", static_cast(0)), FunctionDefHelper::Const("DenseStrDefault", string("")), {{"Parse"}, @@ -1420,6 +1279,7 @@ TEST(VectorizerTest, VectorizeParseSingleExample) { {"sparse_keys", gtl::ArraySlice({"spar_int", "spar_str"})}, {"sparse_types", DataTypeVector({DT_INT64, DT_STRING})}, }}}, + /*ret_def=*/ { {"si0", "Parse:sparse_indices:0"}, {"si1", "Parse:sparse_indices:1"}, @@ -1431,34 +1291,9 @@ TEST(VectorizerTest, VectorizeParseSingleExample) { {"dv1", "Parse:dense_values:1"}, }); - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: string"}, - {"si0: int64", "si1: int64", "sv0: int64", "sv1: string", "ss0: int64", - "ss1: int64", "dv0: int64", "dv1: string"}, - {/* attrs */}, {/* nodes */}, - { - {"si0", "MapDefun:output:0"}, - {"si1", "MapDefun:output:1"}, - {"sv0", "MapDefun:output:2"}, - {"sv1", "MapDefun:output:3"}, - {"ss0", "MapDefun:output:4"}, - {"ss1", "MapDefun:output:5"}, - {"dv0", "MapDefun:output:6"}, - {"dv1", "MapDefun:output:7"}, - }); - - NodeDef* map_defun = AddMapDefunNode( - "MapDefun", {"outer_arg0"}, {DT_STRING}, - {DT_INT64, DT_INT64, DT_INT64, DT_STRING, DT_INT64, DT_INT64, DT_INT64, - DT_STRING}, - std::vector(8), inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); - FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); EXPECT_TRUE( @@ -1467,8 +1302,11 @@ TEST(VectorizerTest, VectorizeParseSingleExample) { TEST(VectorizerTest, VectorizeParseSingleExampleWithStackedDefaults) { FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: string", "arg1: string"}, - {"dv0: int64", "dv1: string"}, {/* attrs */}, + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: string", "arg1: string"}, + /*out_def=*/{"dv0: int64", "dv1: string"}, + /*attr_def=*/{}, + /*node_def=*/ {FunctionDefHelper::Const("DenseIntDefault", static_cast(0)), {{"Parse"}, "ParseSingleExample", @@ -1481,33 +1319,67 @@ TEST(VectorizerTest, VectorizeParseSingleExampleWithStackedDefaults) { {"sparse_keys", gtl::ArraySlice({})}, {"sparse_types", DataTypeVector({})}, }}}, + /*ret_def=*/ { {"dv0", "Parse:dense_values:0"}, {"dv1", "Parse:dense_values:1"}, }); - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: string", "outer_arg1: string"}, - {"dv0: int64", "dv1: string"}, {/* attrs */}, {/* nodes */}, - { - {"dv0", "MapDefun:output:0"}, - {"dv1", "MapDefun:output:1"}, - }); + FunctionDefLibrary lib; + FunctionDef* vectorized; + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); + EXPECT_TRUE( + function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); +} - NodeDef* map_defun = AddMapDefunNode( - "MapDefun", {"outer_arg0", "outer_arg1"}, {DT_STRING, DT_STRING}, - {DT_INT64, DT_STRING}, std::vector(8), - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); +TEST(VectorizerTest, VectorizeTranspose) { + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"out: int32"}, + /*attr_def=*/{}, + /*node_def=*/ + {FunctionDefHelper::Const("Perm", gtl::ArraySlice({1, 0})), + {{"Transpose"}, + "Transpose", + {"arg0", "Perm:output:0"}, + {{"T", DT_INT32}, {"Tperm", DT_INT32}}}}, + /*ret_def=*/{{"out", "Transpose:y:0"}}); FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); - EXPECT_TRUE( + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); + EXPECT_FALSE( + function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); +} + +TEST(VectorizerTest, VectorizeIdentity) { + FunctionDef inner = FunctionDefHelper::Create( + /*function_name=*/"inner_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int32"}, + /*attr_def=*/{}, + /*node_def=*/{{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}}, + /*ret_def=*/{{"ret0", "Identity:output:0"}}); + + FunctionDefLibrary lib; + FunctionDef* vectorized; + TF_ASSERT_OK(WrapAndVectorize(inner, &lib, &vectorized)); + + EXPECT_FALSE( function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + ASSERT_TRUE( + function_utils::ContainsFunctionNodeWithOp("Identity", *vectorized)); + const NodeDef& identity_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Identity", *vectorized)); + + EXPECT_EQ(identity_node.input(0), + vectorized->signature().input_arg(0).name()); + EXPECT_EQ(GetRetval(*vectorized, 0), + strings::StrCat(identity_node.name(), ":output:0")); + EXPECT_EQ(vectorized->node_def_size(), 1); } + } // namespace } // namespace vectorization_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 0938c27b1f3338..7fee3ae9d51bcd 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -57,7 +57,7 @@ bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { } // namespace bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const { - if (!IsIdentity(node) && !IsIdentityNSingleInput(node)) { + if (!IsIdentity(node) && !IsIdentityN(node)) { return true; } @@ -133,15 +133,53 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const { return true; } +int DependencyOptimizer::NumEdgesIfBypassed( + const NodeDef& node, const std::vector& output_nodes) const { + const bool is_multi_input_identity_n = + IsIdentityN(node) && !IsIdentityNSingleInput(node); + const int num_outputs = output_nodes.size(); + const int num_inputs = node.input_size(); + + if (is_multi_input_identity_n) { + // multi-input identity_n with input/output control dependencies will likely + // increase number of edges after optimization. + int num_edges_if_bypassed(0); + for (string input_node_name : node.input()) { + if (IsControlInput(input_node_name)) { + num_edges_if_bypassed += num_outputs; + } else { + ++num_edges_if_bypassed; + } + } + + for (auto consumer : output_nodes) { + for (int j = 0; j < consumer->input_size(); ++j) { + const TensorId consumer_input = ParseTensorName(consumer->input(j)); + if (consumer_input.node() == node.name()) { + if (IsControlInput(consumer_input)) { + num_edges_if_bypassed += num_inputs; + } else { + ++num_edges_if_bypassed; + } + } + } + } + return num_edges_if_bypassed; + } else { + return num_inputs * num_outputs; + } +} + bool DependencyOptimizer::BypassingNodeIsBeneficial( const NodeDef& node, const std::vector& input_nodes, const std::vector& output_nodes) const { const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node); + const bool is_multi_input_identity_n = + IsIdentityN(node) && !IsIdentityNSingleInput(node); const int num_outputs = output_nodes.size(); const int num_inputs = node.input_size(); - // Don't increase the number of edges in the graph. - if (num_inputs * num_outputs > num_inputs + num_outputs) { + if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) { return false; } @@ -166,7 +204,9 @@ bool DependencyOptimizer::BypassingNodeIsBeneficial( for (NodeDef* output_node : output_nodes) { num_cross_out += static_cast(output_node->device() != node_dev); } - if (is_identity && num_cross_in > 0 && num_cross_out > 0) { + + if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 && + num_cross_out > 0) { // This identity node follows a device crossing, so it might be // following a _Recv node after partioning. Do not remove such nodes, // unless they only have consumers on the same device as themselves. @@ -194,6 +234,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx, NodeDef* node = optimized_graph_->mutable_node(node_idx); const bool is_noop = IsNoOp(*node); const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node); + const bool is_multi_input_identity = + IsIdentityN(*node) && !IsIdentityNSingleInput(*node); const string node_name = node->name(); // Constant nodes with no input control dependency are always executed early, // so we can prune all their output control dependencies. @@ -203,11 +245,9 @@ void DependencyOptimizer::OptimizeNode(int node_idx, bool optimize_fanout = false; bool data_connection = false; for (int i = fanout->input_size() - 1; i >= 0; --i) { - int pos; - StringPiece input_name = - ParseNodeNameAsStringPiece(fanout->input(i), &pos); - if (input_name == node_name) { - if (pos < 0) { + const TensorId input_tensor = ParseTensorName(fanout->input(i)); + if (input_tensor.node() == node_name) { + if (input_tensor.index() < 0) { fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1); fanout->mutable_input()->RemoveLast(); optimize_fanout = true; @@ -315,7 +355,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx, // y --^> | | --^> b /\ +---+ // +----------+ y --^> b - if (is_noop || (is_identity && SafeToRemoveIdentity(*node))) { + if (is_noop || ((is_identity || is_multi_input_identity) && + SafeToRemoveIdentity(*node))) { const auto& output_node_set = node_map_->GetOutputs(node_name); const std::vector output_nodes(output_node_set.begin(), output_node_set.end()); @@ -343,34 +384,30 @@ void DependencyOptimizer::OptimizeNode(int node_idx, const NodeDef* input = input_nodes[i]; // Forward dependency from input to consumer if it doesn't already // depend on it. - if (is_identity && i == 0) { + if ((is_identity && i == 0) || + (is_multi_input_identity && !IsControlInput(node->input(i)))) { // Replace regular input from Identity node. - bool found_input = false; string new_input; - const string& input_to_forward = node->input(0); + const string& input_to_forward = node->input(i); CHECK(!IsControlInput(input_to_forward)); for (int j = 0; j < consumer->input_size(); ++j) { - const string& old_input = consumer->input(j); - int old_input_pos; - StringPiece old_input_node_name = - ParseNodeNameAsStringPiece(old_input, &old_input_pos); - if (old_input_node_name == node_name) { - if (old_input_pos >= 0) { + const TensorId old_input = ParseTensorName(consumer->input(j)); + if (old_input.node() == node_name) { + if (old_input.index() == i) { // Regular input new_input = input_to_forward; - node_map_->UpdateInput(consumer->name(), old_input, new_input); + node_map_->UpdateInput(consumer->name(), old_input.ToString(), + new_input); consumer->set_input(j, new_input); - found_input = true; - } else { + } else if (old_input.index() == -1) { // Control dependency new_input = AsControlDependency(NodeName(input_to_forward)); - node_map_->UpdateInput(consumer->name(), old_input, new_input); + node_map_->UpdateInput(consumer->name(), old_input.ToString(), + new_input); consumer->set_input(j, new_input); - found_input = true; } } } - CHECK(found_input); updated_consumer = true; } else { // Forward dependency from input to consumer if it doesn't already @@ -415,7 +452,7 @@ Status DependencyOptimizer::OptimizeDependencies() { std::set nodes_to_delete; for (int i = 0; i < optimized_graph_->node_size(); ++i) { const NodeDef& node = optimized_graph_->node(i); - if (IsNoOp(node) || IsIdentity(node) || IsIdentityNSingleInput(node) || + if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) || IsConstant(node) || SafeToConvertToNoOp(node)) { nodes_to_simplify.PushBack(i); } diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index 48cfa236af847a..7b032673fb3456 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -48,7 +48,8 @@ class DependencyOptimizer : public GraphOptimizer { bool BypassingNodeIsBeneficial( const NodeDef& node, const std::vector& input_nodes, const std::vector& output_nodes) const; - + int NumEdgesIfBypassed(const NodeDef& node, + const std::vector& output_nodes) const; // Returns true if node is not an Identity node or if it is an Identity // that is safe to remove. bool SafeToRemoveIdentity(const NodeDef& node) const; diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index c0f07562affcde..8d70d9d5c73690 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -634,7 +634,7 @@ TEST_F(DependencyOptimizerTest, IdentityInputs) { EXPECT_EQ("s:1", output.node(5).input(0)); } -TEST_F(DependencyOptimizerTest, IdentityN) { +TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL); Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT); @@ -643,8 +643,6 @@ TEST_F(DependencyOptimizerTest, IdentityN) { // IdentityN nodes to be removed. auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false}); auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true}); - - // IdentityN node that can't be removed. auto id_b = ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true}); @@ -663,22 +661,50 @@ TEST_F(DependencyOptimizerTest, IdentityN) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(9, output.node_size()); - EXPECT_EQ("out1", output.node(5).name()); - EXPECT_EQ(1, output.node(5).input_size()); - EXPECT_EQ("s", output.node(5).input(0)); + EXPECT_EQ(8, output.node_size()); + + auto out1_node = output.node(7); + EXPECT_EQ("out1", out1_node.name()); + EXPECT_EQ(1, out1_node.input_size()); + EXPECT_EQ("s", out1_node.input(0)); + + auto out2_node = output.node(4); + EXPECT_EQ("out2", out2_node.name()); + EXPECT_EQ(1, out2_node.input_size()); + EXPECT_EQ("s:1", out2_node.input(0)); + + auto out3_node = output.node(5); + EXPECT_EQ("out3", out3_node.name()); + EXPECT_EQ(1, out3_node.input_size()); + EXPECT_EQ("s", out3_node.input(0)); + + auto out4_node = output.node(6); + EXPECT_EQ("out4", out4_node.name()); + EXPECT_EQ(1, out4_node.input_size()); + EXPECT_EQ("s:1", out4_node.input(0)); +} - EXPECT_EQ("out2", output.node(6).name()); - EXPECT_EQ(1, output.node(6).input_size()); - EXPECT_EQ("s:1", output.node(6).input(0)); +TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output input1 = ops::Placeholder(scope.WithOpName("input1"), DT_BOOL); + Output input2 = ops::Const(scope.WithOpName("input2"), {1, 2}); + + auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {input1, input2}); + Output out1 = ops::Identity(scope.WithOpName("out1"), id_n[0]); + Output out2 = ops::Identity(scope.WithOpName("out2"), id_n[1]); + auto out3 = + ops::NoOp(scope.WithOpName("out3").WithControlDependencies(id_n[1])); - EXPECT_EQ("out3", output.node(7).name()); - EXPECT_EQ(1, output.node(7).input_size()); - EXPECT_EQ("id_b", output.node(7).input(0)); + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch = {"out1", "out2", "out3"}; + + DependencyOptimizer optimizer; + GraphDef optimized_graph_def; + Status status = optimizer.Optimize(nullptr, item, &optimized_graph_def); + TF_EXPECT_OK(status); - EXPECT_EQ("out4", output.node(8).name()); - EXPECT_EQ(1, output.node(8).input_size()); - EXPECT_EQ("id_b:1", output.node(8).input(0)); + EXPECT_EQ(6, optimized_graph_def.node_size()); } TEST_F(DependencyOptimizerTest, diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index cc04ed3340bcd4..f99826ddcad1fe 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -31,8 +31,8 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" @@ -219,8 +219,7 @@ class FunctionOptimizerContext { : grappler_item_id_(item.id), graph_version_(item.graph.versions().producer()), function_library_(OpRegistry::Global(), item.graph.library()), - // GraphView doesn't not modify the graph or the nodes. - graph_view_(const_cast(&item.graph)) { + graph_view_(&item.graph) { InitializeTrulyConstNodes(item); InitializeInlinedFunctions(opt_level, item); InitializeFetchNodes(item); @@ -1133,7 +1132,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Function specialization might change the number of function outputs, so we // have to process the final optimized graph and update all the node mapping. if (ctx.RequiresOutputMapping()) { - GraphView optimized_graph_view(optimized_graph); + MutableGraphView optimized_graph_view(optimized_graph); for (const auto& output_mapping : ctx.output_mappings()) { const auto& node_name = output_mapping.first; const auto& mappings = output_mapping.second; @@ -1143,11 +1142,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, int to = mapping.second; // Get the output port corresponding to the old output position. - GraphView::OutputPort from_port = + MutableGraphView::OutputPort from_port = optimized_graph_view.GetOutputPort(node_name, from); // Update all input ports that read from old output port. - for (GraphView::InputPort to_port : + for (MutableGraphView::InputPort to_port : optimized_graph_view.GetFanout(from_port)) { *to_port.node->mutable_input(to_port.port_id) = strings::StrCat(node_name, ":", to); diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc index 1ea57f7b4f003e..82c408b521f58b 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h" +#include "tensorflow/core/graph/tensor_id.h" namespace tensorflow { namespace grappler { @@ -46,25 +47,27 @@ Status GetTensorProperties(const GraphOptimizerContext& ctx, return errors::InvalidArgument("Graph properties are unknown."); } - int port; - string tensor_node_name = ParseNodeName(tensor, &port); - if (port < 0) { + // TODO(ezhulenev): Make it TensorId when graph properties will support + // absl::string_view lookup. + SafeTensorId tensor_id = ParseTensorName(tensor); + + if (tensor_id.index() < 0) { return errors::InvalidArgument( "Can't get tensor properties of control dependency ", tensor); } const auto& output_properties = - ctx.graph_properties->GetOutputProperties(tensor_node_name); + ctx.graph_properties->GetOutputProperties(tensor_id.node()); auto num_outputs = output_properties.size(); - if (num_outputs == 0 || port > num_outputs - 1) { + if (num_outputs == 0 || tensor_id.index() > num_outputs - 1) { return errors::InvalidArgument( - "Node ", tensor_node_name, - " is missing output properties at position :", port, + "Node ", tensor_id.node(), + " is missing output properties at position :", tensor_id.index(), " (num_outputs=", num_outputs, ")"); } - properties->CopyFrom(output_properties[port]); + properties->CopyFrom(output_properties[tensor_id.index()]); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index c74a94094946e2..775fb9a95f2a71 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -29,8 +30,8 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/evaluation_utils.h" @@ -565,13 +566,14 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node, return Status::OK(); } -Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node, - const NodeMap& node_map, DeviceBase* cpu_device, - ResourceMgr* resource_mgr, bool* has_dead_fanout, - int* dead_fanout) { +Status CheckForDeadFanout(const MutableGraphView& view, + const NodeDef& switch_node, const NodeMap& node_map, + DeviceBase* cpu_device, ResourceMgr* resource_mgr, + bool* has_dead_fanout, int* dead_fanout) { *has_dead_fanout = false; GraphView::InputPort switch_loopcond_port(&switch_node, 1); - NodeDef* switch_predicate = view.GetRegularFanin(switch_loopcond_port).node; + const NodeDef* switch_predicate = + view.GetRegularFanin(switch_loopcond_port).node; // CASE 1: Control is a constant. if (IsConstant(*switch_predicate)) { @@ -582,7 +584,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node, } GraphView::InputPort switch_input_port(&switch_node, 0); - NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node; + const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node; // CASE 2: Zero-iteration while loop. // We check if its a while loop such that the condition is a simple binary @@ -707,10 +709,9 @@ Status LoopOptimizer::RemoveDeadBranches( std::unordered_map> dead_merge_inputs; // TODO(bsteiner): also rewrite switches as identity. For now we just record // them - std::unordered_set - identity_switches; + absl::flat_hash_set identity_switches; - GraphView view(optimized_graph); + MutableGraphView view(optimized_graph); for (const NodeDef& node : optimized_graph->node()) { if (!IsSwitch(node)) { continue; @@ -727,11 +728,12 @@ Status LoopOptimizer::RemoveDeadBranches( if (!has_dead_fanout) { continue; } - GraphView::OutputPort dead(const_cast(&node), dead_fanout); + GraphView::OutputPort dead(&node, dead_fanout); identity_switches.insert(dead); - SetVector zombie_inputs; - for (const GraphView::InputPort& port : view.GetFanout(dead)) { + SetVector> + zombie_inputs; + for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) { if (dead_nodes.find(port.node) == dead_nodes.end()) { zombie_inputs.PushBack(port); } @@ -745,7 +747,7 @@ Status LoopOptimizer::RemoveDeadBranches( dead_merge_inputs; bool found_node_to_preserve = false; while (!found_node_to_preserve && !zombie_inputs.Empty()) { - GraphView::InputPort dead = zombie_inputs.PopBack(); + MutableGraphView::InputPort dead = zombie_inputs.PopBack(); if (nodes_to_preserve.find(dead.node->name()) != nodes_to_preserve.end()) { found_node_to_preserve = true; @@ -764,9 +766,9 @@ Status LoopOptimizer::RemoveDeadBranches( found_node_to_preserve = true; break; } - GraphView::OutputPort value_index(dead.node, 1); - const std::unordered_set& - index_fanout = view.GetFanout(value_index); + MutableGraphView::OutputPort value_index(dead.node, 1); + const absl::flat_hash_set& index_fanout = + view.GetFanout(value_index); if (!index_fanout.empty()) { // The 2nd output (that indicates which input is propagated) is // connected. This never happens in practice, so we'll just skip this @@ -789,7 +791,7 @@ Status LoopOptimizer::RemoveDeadBranches( } if (fully_dead) { local_dead_nodes.insert(dead.node); - for (const GraphView::InputPort& port : + for (const MutableGraphView::InputPort& port : view.GetFanouts(*dead.node, true)) { zombie_inputs.PushBack(port); } @@ -800,7 +802,7 @@ Status LoopOptimizer::RemoveDeadBranches( break; } else { if (local_dead_nodes.insert(dead.node).second) { - for (const GraphView::InputPort& dead_fanout : + for (const MutableGraphView::InputPort& dead_fanout : view.GetFanouts(*dead.node, true)) { zombie_inputs.PushBack(dead_fanout); } diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index c36dc65bb04a4b..e0a913565fc4b9 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -30,8 +30,8 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_memory.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/utils.h" -#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/graph_rewriter.h" #include "tensorflow/core/grappler/optimizers/static_schedule.h" @@ -497,7 +497,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { // Look for AddN nodes (and equivalent) and record input names. - GraphView view(&item->graph); + MutableGraphView view(&item->graph); std::unordered_map> addn_list; for (NodeDef& node : *item->graph.mutable_node()) { @@ -592,7 +592,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { for (int i = 0; i < node->input_size(); ++i) { const string& input = node->input(i); const string node_name = NodeName(input); - NodeDef* node = view.GetNode(node_name); + const NodeDef* node = view.GetNode(node_name); input_topo_index.push_back(topo_order.at(node)); } int min_input_topo_index = INT_MAX; @@ -834,7 +834,8 @@ static const NodeDef* FindSwapInTrigger( return nullptr; } -static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) { +static bool IsSwappable(const MutableGraphView& graph, + MutableGraphView::OutputPort output) { const NodeDef& node = *output.node; // There is no point in swapping out persistent tensors, since the tensor will // continue to use memory. @@ -860,10 +861,10 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) { // If placed on the same device, these nodes are just forwarding references // to their input. Therefore they are swappable iff their fanin is swappable // or it resides on a different device. - GraphView::InputPort input; + MutableGraphView::InputPort input; input.node = output.node; input.port_id = 0; - GraphView::OutputPort fanin = graph.GetRegularFanin(input); + MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input); if (fanin.node->device() == node.device()) { return IsSwappable(graph, fanin); } @@ -872,19 +873,19 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) { } static NodeDef* FindSwapOutTrigger( - const NodeDef* node, int input_id, const GraphView& view, + const NodeDef* node, int input_id, const MutableGraphView& view, const std::unordered_map& execution_times) { // Find the output port that generated the tensor to swap. - GraphView::InputPort swap; + MutableGraphView::InputPort swap; swap.node = const_cast(node); swap.port_id = input_id; - GraphView::OutputPort generator = view.GetRegularFanin(swap); + MutableGraphView::OutputPort generator = view.GetRegularFanin(swap); if (!generator.node) { return nullptr; } - const std::unordered_set& fanout = + const absl::flat_hash_set& fanout = view.GetFanout(generator); NodeDef* trigger = nullptr; Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity()); @@ -903,7 +904,7 @@ static NodeDef* FindSwapOutTrigger( return trigger; } -static bool IsSwappable(GraphView::InputPort input) { +static bool IsSwappable(MutableGraphView::InputPort input) { const NodeDef& node = *input.node; const OpDef* op_def; @@ -920,9 +921,9 @@ static bool IsSwappable(GraphView::InputPort input) { } struct MemInfo { - GraphView::OutputPort port; + MutableGraphView::OutputPort port; int64 memory_used; - std::vector uses_left; + std::vector uses_left; double fitness; bool operator<(const MemInfo& other) const { return fitness < other.fitness; } @@ -993,7 +994,7 @@ static bool IdentifySwappingCandidates( std::vector mem_state; - GraphView graph(&item->graph); + MutableGraphView graph(&item->graph); for (const auto& live_tensor : mem_usage.live_tensors) { if (live_tensor.memory_used <= 1024) { // Don't bother with small tensors. @@ -1009,7 +1010,7 @@ static bool IdentifySwappingCandidates( if (skip_list->find(live_tensor.node) != skip_list->end()) { continue; } - GraphView::OutputPort port = + MutableGraphView::OutputPort port = graph.GetOutputPort(live_tensor.node, live_tensor.output_id); if (!IsSwappable(graph, port)) { continue; @@ -1020,7 +1021,7 @@ static bool IdentifySwappingCandidates( Costs::Duration allocation_time = live_tensor.allocation_time; Costs::Duration earliest_use(Costs::Duration::infinity()); bool valid = true; - for (GraphView::InputPort input : graph.GetFanout(port)) { + for (MutableGraphView::InputPort input : graph.GetFanout(port)) { // Get execution time. auto it = op_completion_times.find(input.node->name()); if (it == op_completion_times.end()) { @@ -1062,7 +1063,7 @@ static bool IdentifySwappingCandidates( // the values do not fit into any integral type. mem_info.fitness = MathUtil::IPow((earliest_use - peak_time).count(), 2) / - MathUtil::IPow(mem_info.uses_left.size(), 2) + + MathUtil::IPow(mem_info.uses_left.size(), 2) + MathUtil::IPow((allocation_time - peak_time).count(), 2); mem_info.fitness = -mem_info.fitness; mem_state.push_back(mem_info); @@ -1073,7 +1074,8 @@ static bool IdentifySwappingCandidates( std::sort(mem_state.begin(), mem_state.end()); for (const MemInfo& mem_info : mem_state) { - for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) { + for (const MutableGraphView::InputPort fanout_to_swap : + mem_info.uses_left) { VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":" << fanout_to_swap.port_id << " of tensor " << mem_info.port.node->name() << ":" << mem_info.port.port_id @@ -1150,7 +1152,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, for (const auto& node : item->graph.node()) { name_map[node.name()] = &node; } - GraphView view(&item->graph); + MutableGraphView view(&item->graph); bool updated_graph = false; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 1d787d2b7c2646..82c88bb06aeca7 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -279,6 +279,18 @@ MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const { return nullptr; } +#define RUN_OPTIMIZER_OR_RETURN_IF_ERROR(optimizer) \ + { \ + const Status status = RunOptimizer(optimizer, cluster, &optimized_item, \ + optimized_graph, &optimization_result); \ + if (status.ok()) { \ + is_optimized = true; \ + } else if (cfg_.fail_on_optimizer_errors()) { \ + VLOG(2) << "Optimizer '" << optimizer->name() << "' failed: " << status; \ + TF_RETURN_IF_ERROR(status); \ + } \ + } + Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes @@ -340,9 +352,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get(); continue; } - Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item, - optimized_graph, &optimization_result); - if (status.ok()) is_optimized = true; + RUN_OPTIMIZER_OR_RETURN_IF_ERROR(optimizer.get()); } } @@ -353,16 +363,12 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, // optimizations from taking place since we don't have shape inference for // functions, and we can't optimize across function boundaries. if (fusion_optimizer != nullptr) { - Status status = RunOptimizer(fusion_optimizer, cluster, &optimized_item, - optimized_graph, &optimization_result); - if (status.ok()) is_optimized = true; + RUN_OPTIMIZER_OR_RETURN_IF_ERROR(fusion_optimizer); } // ScopedAllocatorOptimizer must run last. if (sa_optimizer != nullptr) { - Status status = RunOptimizer(sa_optimizer, cluster, &optimized_item, - optimized_graph, &optimization_result); - if (status.ok()) is_optimized = true; + RUN_OPTIMIZER_OR_RETURN_IF_ERROR(sa_optimizer); } // Record graph optimization result. @@ -379,6 +385,8 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } +#undef RUN_OPTIMIZER_OR_RETURN_IF_ERROR + Status MetaOptimizer::RunOptimizer( GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item, GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) { diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index d074a9cda592c9..1be87a9d0d516a 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -39,19 +39,12 @@ bool IsTrivialOp(const NodeDef& node, const GraphRewriter& rewriter) { return true; } if (IsIdentity(node) || IsIdentityNSingleInput(node)) { - if (rewriter.FeedsMerge(node) || rewriter.IsDrivenBySwitch(node) || - rewriter.IsDrivenByControlDependency(node) || - rewriter.DrivesControlDependency(node)) { - return false; - } else { - return true; - } - } - if (IsAddN(node) && NumNonControlInputs(node) <= 1) { - return true; + return !(rewriter.FeedsMerge(node) || rewriter.IsDrivenBySwitch(node) || + rewriter.IsDrivenByControlDependency(node) || + rewriter.DrivesControlDependency(node)); } - return false; + return IsAddN(node) && NumNonControlInputs(node) <= 1; } absl::flat_hash_map> IdentityNTerminalPorts( @@ -190,19 +183,18 @@ Status RewriteIdentityNAndInputsOutputs( if (IsControlInput(input)) { continue; } - int pos; - const StringPiece name = ParseNodeNameAsStringPiece(input, &pos); - if (name == node->name()) { - if (terminal_ports.find(pos) == terminal_ports.end()) { + TensorId input_tensor = ParseTensorName(input); + if (input_tensor.node() == node->name()) { + if (terminal_ports.find(input_tensor.index()) == terminal_ports.end()) { // Replace input that does not lead to a terminal node with newly // created identity. - string new_identity = new_identities[pos]; + string new_identity = new_identities[input_tensor.index()]; output->set_input(i, new_identity); updates.push_back({new_identity, output->name()}); } else { // Update input ports that lead to a terminal node from splitting // inputs. - int new_pos = terminal_input_pos[pos]; + int new_pos = terminal_input_pos[input_tensor.index()]; string updated_input_name = new_pos > 0 ? strings::StrCat(node->name(), ":", new_pos) : node->name(); diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index 6ccb1cd783d82e..7dae0e3cd9ef8a 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/symbolic_shapes.h" @@ -34,7 +34,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphProperties properties(item); bool inferred_properties = false; - GraphView graph(optimized_graph); + MutableGraphView graph(optimized_graph); // The product of all the dimensions in a tensor shape can be expressed more // simply as the size of the tensor. @@ -42,8 +42,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (!IsShape(node)) { continue; } - for (GraphView::InputPort fanout : - graph.GetFanout(GraphView::OutputPort(&node, 0))) { + for (MutableGraphView::InputPort fanout : + graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) { if (fanout.node->op() != "Prod") { continue; } @@ -53,8 +53,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // rewrite the whole expression directly as a Size operation. continue; } - const GraphView::OutputPort reduce_indices = - graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1)); + const MutableGraphView::OutputPort reduce_indices = + graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1)); if (!inferred_properties) { // Infer properties lazily in case they are not needed. TF_RETURN_IF_ERROR(properties.InferStatically(false)); @@ -90,10 +90,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // is possible whenever the symbolic dimensions in the numerator and // denominator cancel each other. if (node.op() == "Div") { - const GraphView::OutputPort input1 = - graph.GetRegularFanin(GraphView::InputPort(&node, 0)); - const GraphView::OutputPort input2 = - graph.GetRegularFanin(GraphView::InputPort(&node, 1)); + const MutableGraphView::OutputPort input1 = + graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0)); + const MutableGraphView::OutputPort input2 = + graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1)); if (!IsSize(*input1.node) || !IsSize(*input2.node)) { continue; } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index ac54eae0e5fc5b..f0f0798035cd88 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -25,12 +25,14 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace grappler { @@ -143,20 +145,18 @@ void NodeMap::UpdateOutput(const string& node_name, } bool IsSameInput(const string& name1, const string& name2) { - if (name1 == name2) { - return true; - } - int position1; - StringPiece node1 = ParseNodeNameAsStringPiece(name1, &position1); - int position2; - StringPiece node2 = ParseNodeNameAsStringPiece(name2, &position2); - return (position1 == position2) && (node1 == node2); + if (name1 == name2) return true; + TensorId tensor1 = ParseTensorName(name1); + TensorId tensor2 = ParseTensorName(name2); + return tensor1.node() == tensor2.node() && tensor1.index() == tensor2.index(); } bool IsControlInput(const string& name) { return !name.empty() && name[0] == '^'; } +bool IsControlInput(const TensorId& tensor_id) { return tensor_id.index() < 0; } + string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { @@ -243,7 +243,6 @@ int NumNonControlInputs(const NodeDef& node) { int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) { int num_outputs = 0; - int pos; for (const NodeDef* output : node_map.GetOutputs(node.name())) { for (const string& node_as_input : output->input()) { if (IsControlInput(node_as_input)) { @@ -252,9 +251,8 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) { if (node_as_input == node.name()) { ++num_outputs; } else { - const StringPiece name = - ParseNodeNameAsStringPiece(node_as_input, &pos); - if (name == node.name()) { + const TensorId tensor = ParseTensorName(node_as_input); + if (tensor.node() == node.name()) { ++num_outputs; } } @@ -563,5 +561,14 @@ Status CheckAttrsExist(const NodeDef& node, absl::Span keys) { return Status::OK(); } +Status IsKernelRegisteredForNode(const NodeDef& node) { + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) { + return errors::InvalidArgument("Could not parse device name: ", + node.device()); + } + return FindKernelDef(DeviceType(parsed_name.type), node, nullptr, nullptr); +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 39319eacb73c4a..0f756a2dbd3d12 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -103,6 +104,9 @@ class SetVector { // the ^ character. bool IsControlInput(const string& name); +// True iff tensor index refers to a control input. +bool IsControlInput(const TensorId& tensor_id); + // True iff 'name1' and 'name2' refer to the same input. bool IsSameInput(const string& name1, const string& name2); @@ -165,6 +169,7 @@ inline string NodeName(const string& name) { } // Returns the node name and position in a single call. +// DEPRECATED(ezhulenev): Use TensorId and ParseTensorName. inline StringPiece ParseNodeNameAsStringPiece(const string& name, int* position) { static const string empty; @@ -195,6 +200,7 @@ inline StringPiece ParseNodeNameAsStringPiece(const string& name, } // Returns the node name and position in a single call. +// DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName. inline string ParseNodeName(const string& name, int* position) { return string(ParseNodeNameAsStringPiece(name, position)); } @@ -276,6 +282,10 @@ NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, void PermuteNodesInPlace(GraphDef* graph, std::vector* permutation, bool invert_permutation); +// Returns Status::OK() if a kernel is registered for node.op() on the device +// type corresponding to node.device(). +Status IsKernelRegisteredForNode(const NodeDef& node); + Status SetTensorValue(DataType dtype, int value, Tensor* tensor); void EraseNodesFromGraph(const std::set& nodes_to_delete, GraphDef* graph); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 8559788306137d..dbe425b75fd1bb 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -101,6 +101,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:mutable_graph_view", "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -172,6 +173,7 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index da79b0646a2828..c806f3874ddbfa 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "absl/strings/substitute.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -617,10 +618,16 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, // Instantiate function body into a statically defined graph def. GraphDef function_body; - // Function body shares the library with the graph that instantiated it. It's - // unsafe to prune unreachable functions here, because it might lead to - // conflicting specializations. - *function_body.mutable_library() = flib.ToProto(); + // Function body shares the library with the graph that instantiated it. We do + // not need a full copy of the function library, just the reachable subset. + *function_body.mutable_library() = + ReachableFunctionLibraryDefinition(flib, func).ToProto(); + + VLOG(3) << absl::Substitute( + "Deleted $0 unreachable functions from the Grappler function item " + "instantiation of $1 (library size = $2)", + flib.num_functions() - function_body.library().function_size(), + signature.name(), function_body.library().function_size()); // TODO(ezhulenev): support functions with tensor sequence inputs/outputs @@ -658,7 +665,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, InputArgExpansion input_expansion{/*input_name=*/input.name(), /*data_type=*/input_data_type, - /*is_ref*/ input.is_ref(), + /*is_ref=*/input.is_ref(), /*placeholders=*/{input.name()}}; connectivity.RegisterInputArgExpansion(input_expansion); inputs.push_back(std::move(input_expansion)); diff --git a/tensorflow/core/grappler/utils/traversal.cc b/tensorflow/core/grappler/utils/traversal.cc index e5b2d17ae5524d..6952277568676b 100644 --- a/tensorflow/core/grappler/utils/traversal.cc +++ b/tensorflow/core/grappler/utils/traversal.cc @@ -21,8 +21,11 @@ limitations under the License. namespace tensorflow { namespace grappler { -void ReverseDfs( - const GraphView& graph_view, const std::vector& from, +namespace { + +template +void ReverseDfsInternal( + const GraphViewType& graph_view, const std::vector& from, const std::function& pre_order, const std::function& post_order, const std::function& on_back_edge) { @@ -79,5 +82,25 @@ void ReverseDfs( } } +} // namespace + +void ReverseDfs( + const GraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge) { + ReverseDfsInternal(graph_view, from, pre_order, post_order, + on_back_edge); +} + +void ReverseDfs( + const MutableGraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge) { + ReverseDfsInternal(graph_view, from, pre_order, post_order, + on_back_edge); +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/traversal.h b/tensorflow/core/grappler/utils/traversal.h index 8aa97237cc25bc..5b7737f97eb1f8 100644 --- a/tensorflow/core/grappler/utils/traversal.h +++ b/tensorflow/core/grappler/utils/traversal.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" namespace tensorflow { namespace grappler { @@ -34,6 +35,12 @@ void ReverseDfs( const std::function& post_order, const std::function& on_back_edge); +void ReverseDfs( + const MutableGraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge); + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/traversal_test.cc b/tensorflow/core/grappler/utils/traversal_test.cc index fad26b5a9e34a8..c040477a089704 100644 --- a/tensorflow/core/grappler/utils/traversal_test.cc +++ b/tensorflow/core/grappler/utils/traversal_test.cc @@ -14,9 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/traversal.h" -//#include "tensorflow/core/framework/node_def.pb.h" -//#include "tensorflow/core/lib/core/status_test_util.h" -//#include "tensorflow/core/platform/protobuf.h" + #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -65,8 +63,16 @@ TEST_F(TraversalTest, ReverseDfsNoLoop) { found_back_edge = true; }); - EXPECT_EQ(std::vector({"1", "4", "3", "2", "5", "0"}), pre_order); - EXPECT_EQ(std::vector({"4", "5", "2", "3", "1", "0"}), post_order); + // Pre/Post order traversals are non deterministic because a node fanin is an + // absl::flat_hash_set with non deterministic traversal order. + using ValidTraversal = std::pair, std::vector>; + + std::set valid_traversals = { + // pre_order post_order + {{"1", "4", "3", "2", "5", "0"}, {"4", "5", "2", "3", "1", "0"}}, + {{"1", "3", "2", "5", "4", "0"}, {"5", "2", "3", "4", "1", "0"}}}; + + EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1); EXPECT_FALSE(found_back_edge); } @@ -92,8 +98,17 @@ TEST_F(TraversalTest, ReverseDfsWithLoop) { back_edges.push_back(strings::StrCat(src->name(), "->", dst->name())); }); - EXPECT_EQ(std::vector({"6", "3", "2", "1", "5", "4"}), pre_order); - EXPECT_EQ(std::vector({"1", "4", "5", "2", "3", "6"}), post_order); + // Pre/Post order traversals are non deterministic because a node fanin is an + // absl::flat_hash_set with non deterministic traversal order. + using ValidTraversal = std::pair, std::vector>; + + std::set valid_traversals = { + // pre_order post_order + {{"6", "3", "2", "4", "5", "1"}, {"5", "4", "1", "2", "3", "6"}}, + {{"6", "3", "2", "1", "5", "4"}, {"1", "4", "5", "2", "3", "6"}}, + {{"6", "3", "2", "5", "4", "1"}, {"4", "5", "1", "2", "3", "6"}}}; + + EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1); EXPECT_EQ(std::vector({"4->3"}), back_edges); } diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 447195b0018f5d..8cbff1c397114c 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -384,6 +384,27 @@ TEST_F(UtilsTest, DeleteNodes) { // TODO(rmlarsen): write forgotten test. } +TEST(IsKernelRegisteredForNode, All) { + NodeDef node; + node.set_name("foo"); + node.set_op("NoOp"); + node.set_device("/cpu:0"); + TF_EXPECT_OK(IsKernelRegisteredForNode(node)); + node.set_device("/gpu:0"); + TF_EXPECT_OK(IsKernelRegisteredForNode(node)); + + // Bad device name. + node.set_device(""); + EXPECT_FALSE(IsKernelRegisteredForNode(node).ok()); + + // Check an op that is only defined on CPU. + node.set_op("MatchingFiles"); + node.set_device("/cpu:0"); + TF_EXPECT_OK(IsKernelRegisteredForNode(node)); + node.set_device("/gpu:0"); + EXPECT_FALSE(IsKernelRegisteredForNode(node).ok()); +} + #define BM_NodePositionIfSameNode(I, N, NAME) \ static void BM_NodePositionIfSameNode_##NAME(int iters) { \ string input = I; \ diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5182a72214f305..fed0178176e494 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -93,6 +93,17 @@ config_setting( }, ) +config_setting( + # Add "--define tensorflow_eigen_mkldnn=1" to your build command to use mkldnn + # sgemm in Eigen tensor contractions (matrix multiplications and convolutions). + # The mkldnn kernels are generated at runtime and use avx/avx2/fma/avx512 + # based on cpu status registers (https://en.wikipedia.org/wiki/CPUID). + name = "eigen_mkldnn", + values = { + "define": "tensorflow_eigen_mkldnn=1", + }, +) + # Public support libraries ---------------------------------------------------- cc_library( @@ -555,10 +566,20 @@ cc_library( "eigen_softmax.h", "eigen_spatial_convolutions.h", "eigen_volume_patch.h", - ], + ] + select({ + ":eigen_mkldnn": ["eigen_mkldnn.h"], + "//conditions:default": [], + }), + defines = select({ + ":eigen_mkldnn": ["EIGEN_USE_MKLDNN"], + "//conditions:default": [], + }), deps = [ "//third_party/eigen3", - ], + ] + select({ + ":eigen_mkldnn": ["//third_party/intel_mkl_dnn:mkldnn_single_threaded"], + "//conditions:default": [], + }), ) cc_library( @@ -1887,10 +1908,22 @@ tf_kernel_library( deps = DATA_FLOW_DEPS, ) +cc_library( + name = "stack", + srcs = ["stack.cc"], + hdrs = ["stack.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + tf_kernel_library( name = "stack_ops", prefix = "stack_ops", - deps = DATA_FLOW_DEPS, + deps = DATA_FLOW_DEPS + [":stack"], ) tf_kernel_library( @@ -2393,12 +2426,25 @@ tf_cc_tests( ], deps = [ ":eigen_helpers", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", + ], +) + +# Conditional test target generation is not supported by the "tf_cc_tests" macro +# (can't add 'select' to the srcs field, type 'select' is not iterable). +tf_cc_test( + name = "eigen_mkldnn_test", + size = "small", + srcs = select({ + ":eigen_mkldnn": ["eigen_mkldnn_test.cc"], + "//conditions:default": [], + }), + tags = ["eigen_mkldnn"], + deps = [ + ":eigen_helpers", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -5489,6 +5535,8 @@ filegroup( "sparse_to_dense_op.cc", "spectrogram.cc", "spectrogram_op.cc", + "stack.cc", + "stack.h", "stack_ops.cc", "string_join_op.cc", "string_util.cc", diff --git a/tensorflow/core/kernels/constant_op_test.cc b/tensorflow/core/kernels/constant_op_test.cc index 0faad11e4721c9..3988c190e701c8 100644 --- a/tensorflow/core/kernels/constant_op_test.cc +++ b/tensorflow/core/kernels/constant_op_test.cc @@ -79,7 +79,7 @@ void ConstantOpTest::PersistentMemoryTrackingTest(bool on_gpu) { } // Remove memory leak errors. - for (auto allocator_pair : ctx.wrapped_allocators()) { + for (auto allocator_pair : ctx.ConsumeWrappedAllocators()) { allocator_pair.second->GetRecordsAndUnRef(); } } diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index dfcc302f468e74..fbd702ef14ed2b 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -515,14 +515,17 @@ struct CudnnRnnModelShapes { // key. struct CudnnRnnConfigHasher { uint64 operator()( - const std::pair& to_hash) const { + const std::pair>& + to_hash) const { auto& shapes = to_hash.first; auto& algo_desc = to_hash.second; uint64 hash = HashList({shapes.num_layers, shapes.input_size, shapes.num_units, shapes.dir_count, shapes.batch_size}); - hash = Hash64Combine(hash, algo_desc.hash()); + if (algo_desc.has_value()) { + hash = Hash64Combine(hash, algo_desc->hash()); + } return hash; } }; @@ -531,8 +534,9 @@ struct CudnnRnnConfigHasher { // table key. struct CudnnRnnConfigComparator { bool operator()( - const std::pair& lhs, - const std::pair& rhs) const { + const std::pair>& lhs, + const std::pair>& rhs) + const { return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second; } }; @@ -887,10 +891,9 @@ class CudnnRNNKernelCommon : public OpKernel { return Status::OK(); } - using RnnStateCache = - gtl::FlatMap, - RnnScratchSpace, CudnnRnnConfigHasher, - CudnnRnnConfigComparator>; + using RnnStateCache = gtl::FlatMap< + std::pair>, + RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>; // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and // should outlive the returned pointer. template @@ -1317,9 +1320,9 @@ class CudnnRNNForwardOpV2 OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}), &output_host_reserved)); auto output_host_reserved_int8 = output_host_reserved->vec(); - output_host_reserved_int8(0) = best_algo_config.algorithm().algo_id(); + output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id(); output_host_reserved_int8(1) = - best_algo_config.algorithm().tensor_ops_enabled(); + best_algo_config.algorithm()->tensor_ops_enabled(); } else { OP_REQUIRES_OK(context, context->allocate_output(4, {}, &output_host_reserved)); @@ -1359,8 +1362,8 @@ class CudnnRNNForwardOpV2 if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) { VLOG(1) << "Using existing best Cudnn RNN algorithm " << "(algo, tensor_op_enabled) = (" - << algo_config->algorithm().algo_id() << ", " - << algo_config->algorithm().tensor_ops_enabled() << ")."; + << algo_config->algorithm()->algo_id() << ", " + << algo_config->algorithm()->tensor_ops_enabled() << ")."; return Status::OK(); } diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index fb306a2e5beea2..cbcae0588c6e28 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -117,7 +117,6 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -199,6 +198,12 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + dataset()->batch_size_); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 0db5031152ba5c..ce6fd09aee53a4 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -133,6 +133,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_)); @@ -243,6 +249,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (iteration_completed_) { @@ -468,6 +480,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( @@ -683,6 +701,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_)); @@ -799,6 +823,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); return SaveInput(writer, input_impl_); @@ -825,6 +855,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_)); diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 0bb929b3ce6c68..64834e507f2d5b 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -19,8 +19,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/util/ptr_util.h" @@ -75,6 +77,8 @@ class SimpleStepStatsCollector : public StepStatsCollectorInterface { end_time_ns_ = Env::Default()->NowNanos(); } + bool TrackAllocations() const override { return false; } + void SetMemory(OpKernelContext* ctx) override {} void SetOutput(int slot, const Tensor* tensor) override {} @@ -451,15 +455,17 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, CancellationManager* c_mgr = new CancellationManager; f_opts.cancellation_manager = c_mgr; std::shared_ptr stats_collector; - if (ctx->model()) { + if (ctx->model() || ctx->stats_aggregator()) { stats_collector = MakeUnique(); } f_opts.stats_collector = stats_collector.get(); auto callback = std::bind( - [rets, step_container, c_mgr, frame]( + [this, rets, step_container, c_mgr, frame]( const FunctionLibraryRuntime::DoneCallback& done, - const std::shared_ptr& model, const string& prefix, + const std::shared_ptr& model, + const std::shared_ptr& stats_aggregator, + const string& prefix, const std::shared_ptr& stats_collector, // Begin unbound arguments. Status s) { @@ -469,6 +475,14 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, s = frame->ConsumeRetvals(rets); } delete frame; + + if (stats_aggregator) { + stats_aggregator->AddToHistogram( + strings::StrCat( + str_util::Split(prefix, "::", str_util::SkipEmpty()).back(), + "::", func_.name(), "::execution_time"), + {static_cast(stats_collector->processing_time())}); + } if (model) { model->AddProcessingTime(prefix, stats_collector->processing_time()); model->RecordStart(prefix, false /* stop_output */); @@ -478,8 +492,8 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, model->RecordStop(prefix, false /* start_output */); } }, - std::move(done), ctx->model(), prefix, std::move(stats_collector), - std::placeholders::_1); + std::move(done), ctx->model(), ctx->stats_aggregator(), prefix, + std::move(stats_collector), std::placeholders::_1); ctx->lib()->Run(f_opts, handle, frame, std::move(callback)); } diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 1df988d8456712..d5a0abc64b4e87 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -129,6 +129,12 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index b889e4eda97be4..d684d23b24212e 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -272,6 +272,13 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode( + std::move(args), + DatasetIterator>::dataset()->batch_size_); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(Iterator::SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 5e54222c3ec259..1a18864ecf5619 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -147,6 +147,16 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "non_serializable_dataset_op", + srcs = ["non_serializable_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + tf_kernel_library( name = "dataset_kernels", deps = [ @@ -156,6 +166,7 @@ tf_kernel_library( ":ignore_errors_dataset_op", ":indexed_dataset", ":lmdb_dataset_op", + ":non_serializable_dataset_op", ":numa_map_and_batch_dataset_op", ":prefetching_kernels", ":sleep_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index 3511cca0f522b5..3b5ee9b783c7c6 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -122,6 +122,12 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc index cc3dcc1612b3b8..f6f58fc430b41d 100644 --- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -263,6 +263,11 @@ class CSVDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index c47a9099c4afc3..d8bb696167a797 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -202,6 +202,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeInterleaveManyNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (selector_input_impl_) { diff --git a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc index 45515bb1dac88e..d10a3dea110c9c 100644 --- a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc +++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc @@ -91,6 +91,13 @@ class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { return Status::OK(); } + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + private: mutex mu_; uint64 cur_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc index b34377c6429837..57cb44335b17f3 100644 --- a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc @@ -103,6 +103,12 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc index 9e9a13d56d1a54..6248eb775e481c 100644 --- a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc @@ -127,6 +127,11 @@ class LMDBDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { return errors::Unimplemented( "Checkpointing is currently not supported for LMDBDataset."); diff --git a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc new file mode 100644 index 00000000000000..953e086de3786b --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { +namespace { + +class NonSerializableDatasetOp : public UnaryDatasetOpKernel { + public: + explicit NonSerializableDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + *output = new Dataset(ctx, input, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::NonSerializable")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "NonSerializableDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented(DebugString(), "::AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + std::unique_ptr input_impl_; + }; + + const DatasetBase* input_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; + + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalNonSerializableDataset").Device(DEVICE_CPU), + NonSerializableDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc index 517cdd25963ef2..068f854023064a 100644 --- a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc @@ -200,16 +200,9 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); - AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); if (num_parallel_calls_->value == kAutoTune) { - num_parallel_calls_->value = std::max(1, port::NUMANumNodes()); - AddTunableParameter(ctx, - /* name = */ "parallelism", - /* state = */ num_parallel_calls_, - /* min = */ num_parallel_calls_->value, - /* max = */ port::NumSchedulableCPUs()); - } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); + num_parallel_calls_->value = ctx->runner_threadpool_size(); + num_parallel_calls_->tunable = true; } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); @@ -246,6 +239,14 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncKnownRatioNode( + std::move(args), dataset()->batch_size_, + {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, + /*max=*/ctx->runner_threadpool_size())}); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(*mu_); for (size_t i = 0; i < workers_.size(); ++i) { @@ -906,8 +907,7 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { new_ctx = std::make_shared(*ctx); } workers_[i]->threads.emplace_back(ctx->env()->StartThread( - {}, - strings::StrCat("numa_map_and_batch_block_", i, "_thread_", j), + {}, strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j), [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); })); VLOG(3) << "Worker " << i << ", " << j << " successfully started."; } @@ -917,7 +917,7 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { new_ctx = std::make_shared(*ctx); } runner_thread_.reset(ctx->env()->StartThread( - {}, "numa_map_runner_thread", + {}, "tf_data_numa_map_and_batch", [this, new_ctx] { RunnerThread(new_ctx); })); } VLOG(3) << "All workers & runner thread started."; diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc index fba63056be6201..c7bf89cbdeb4f8 100644 --- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc @@ -107,6 +107,12 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { return SaveInput(writer, input_impl_); } diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index d6f697249f4d59..ab21dfc6bc5ddc 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -47,6 +47,8 @@ class ThreadPoolResource : public ResourceBase { } } + int32 NumThreads() { return thread_pool_.NumThreads(); } + string DebugString() override { return "ThreadPoolResource"; } private: @@ -192,18 +194,20 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { ThreadPoolResource* pool = dataset()->threadpool_; - IteratorContext::Params params; - params.env = ctx->env(); + IteratorContext::Params params(ctx); params.runner = [pool](std::function c) { pool->Schedule(std::move(c)); }; - params.stats_aggregator = ctx->stats_aggregator(); - params.lib = ctx->lib(); - params.function_library = ctx->function_library(); - params.allocator_getter = ctx->allocator_getter(); - IteratorContext threadpool_ctx(params); - return input_impl_->GetNext(&threadpool_ctx, out_tensors, - end_of_sequence); + params.runner_threadpool_size = pool->NumThreads(); + IteratorContext iter_ctx(params); + return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); } private: diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc index cd612e0eb25366..23dd9ff612db61 100644 --- a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -114,6 +114,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownRatioNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc index 01fdf8001ba03c..784f9872860fee 100644 --- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -142,6 +142,11 @@ class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownRatioNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index ba786753bedf96..40cbb124252aa4 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -234,6 +234,11 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownRatioNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index ff3fe2ea94a1ee..9b42981ed75aff 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -165,6 +165,11 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeInterleaveManyNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { @@ -242,16 +247,6 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { ¤t_element_iterator_); } - Status BuildCurrentElementIteratorLocked(OpKernelContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.lib = ctx->function_library(); - IteratorContext iter_ctx(std::move(params)); - return BuildCurrentElementIteratorLocked(&iter_ctx); - } - mutex mu_; size_t element_index_ GUARDED_BY(mu_) = 0; std::unique_ptr input_impl_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 5de2e2871dd757..ed18d6ed9d8d8f 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -125,6 +125,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return s; } + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + private: mutex mu_; bool initialized_ GUARDED_BY(mu_) = false; diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index bf0189eb063dd6..dc1925a21fe039 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -267,6 +267,11 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownRatioNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 41711eaa9826b0..64db5df31ebdf2 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -281,6 +281,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownRatioNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 42b933bd09794f..9574e400a2db6c 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -209,6 +209,11 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeInterleaveManyNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 50b72f46c204e4..445718ba1e532a 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -68,8 +68,10 @@ class IteratorResource : public ResourceBase { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { CHECK_NOTNULL(lib_); - ctx->set_lib(lib_); - return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence); + IteratorContext::Params params(ctx); + params.lib = lib_; + return captured_iterator->GetNext(IteratorContext(std::move(params)), + out_tensors, end_of_sequence); } else { return errors::FailedPrecondition( "GetNext() failed because the iterator has not been initialized. " @@ -78,6 +80,11 @@ class IteratorResource : public ResourceBase { } } + Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, + bool* end_of_sequence) { + return GetNext(&ctx, out_tensors, end_of_sequence); + } + Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { @@ -124,24 +131,21 @@ class IteratorResource : public ResourceBase { TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); std::unique_ptr iterator; - IteratorContext iter_ctx(ctx); - iter_ctx.set_lib(lib); - TF_RETURN_IF_ERROR( - dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); + IteratorContext::Params params(ctx); + params.lib = lib; + TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), + "Iterator", &iterator)); TF_RETURN_IF_ERROR(set_iterator(std::move(iterator))); std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); + IteratorContext::Params params(ctx); params.lib = lib; DeviceBase* device = lib->device(); params.allocator_getter = [device](AllocatorAttributes attrs) { return device->GetAllocator(attrs); }; IteratorContext iter_ctx(std::move(params)); - TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader)); mutex_lock l(mu_); device_mgr_ = std::move(device_mgr); @@ -582,10 +586,10 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { core::ScopedUnref unref(iterator_resource); std::unique_ptr iterator; - IteratorContext iter_ctx(ctx); - iter_ctx.set_lib(iterator_resource->function_library_runtime()); - OP_REQUIRES_OK( - ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); + IteratorContext::Params params(ctx); + params.lib = iterator_resource->function_library_runtime(); + OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(std::move(params)), + "Iterator", &iterator)); OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } @@ -595,9 +599,7 @@ class ToSingleElementOp : public AsyncOpKernel { public: explicit ToSingleElementOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_(ctx->env(), - strings::StrCat("to_single_element_op_thread_", - SanitizeThreadSuffix(name()))) {} + background_worker_(ctx->env(), "tf_data_to_single_element") {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { // The call to `iterator->GetNext()` may block and depend on an @@ -663,9 +665,7 @@ class ReduceDatasetOp : public AsyncOpKernel { public: explicit ReduceDatasetOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_( - ctx->env(), - strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) { + background_worker_(ctx->env(), "tf_data_reduce_dataset") { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -772,10 +772,7 @@ class OneShotIteratorOp : public AsyncOpKernel { public: explicit OneShotIteratorOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_( - ctx->env(), - strings::StrCat("one_shot_iterator_initialization_thread_", - SanitizeThreadSuffix(name()))), + background_worker_(ctx->env(), "tf_data_one_shot_iterator"), graph_def_version_(ctx->graph_def_version()) { @@ -920,10 +917,10 @@ class OneShotIteratorOp : public AsyncOpKernel { DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); std::unique_ptr iter; - IteratorContext iter_ctx(ctx); - iter_ctx.set_lib(lib); - TF_RETURN_IF_ERROR( - dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter)); + IteratorContext::Params params(ctx); + params.lib = lib; + TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), + "Iterator", &iter)); TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter))); (*iterator)->Ref(); @@ -979,17 +976,10 @@ void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { std::vector components; bool end_of_sequence = false; - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); + IteratorContext::Params params(ctx); params.function_library = iterator->function_library(); - DeviceBase* device = ctx->function_library()->device(); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - IteratorContext iter_ctx(std::move(params)); - - Status s = iterator->GetNext(&iter_ctx, &components, &end_of_sequence); + Status s = iterator->GetNext(IteratorContext(std::move(params)), + &components, &end_of_sequence); // NOTE(mrry): We must unref the iterator before calling `done()`, to // avoid destruction races. iterator->Unref(); @@ -1013,22 +1003,12 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) { IteratorResource* iterator; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); core::ScopedUnref unref_iterator(iterator); - std::vector components; bool end_of_sequence = false; - - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); + IteratorContext::Params params(ctx); params.function_library = iterator->function_library(); - DeviceBase* device = ctx->function_library()->device(); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - IteratorContext iter_ctx(std::move(params)); - - OP_REQUIRES_OK(ctx, - iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); + OP_REQUIRES_OK(ctx, iterator->GetNext(IteratorContext(std::move(params)), + &components, &end_of_sequence)); OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); for (int i = 0; i < components.size(); ++i) { @@ -1043,9 +1023,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel { public: explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_( - ctx->env(), strings::StrCat("iterator_get_next_as_optional_thread_", - SanitizeThreadSuffix(name()))) { + background_worker_(ctx->env(), + "tf_data_iterator_get_next_as_optional") { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -1062,18 +1041,10 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel { std::vector components; bool end_of_sequence = false; - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); + IteratorContext::Params params(ctx); params.function_library = iterator->function_library(); - DeviceBase* device = ctx->function_library()->device(); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - IteratorContext iter_ctx(std::move(params)); - - Status s = - iterator->GetNext(&iter_ctx, &components, &end_of_sequence); + Status s = iterator->GetNext(IteratorContext(std::move(params)), + &components, &end_of_sequence); // NOTE(mrry): We must unref the iterator before calling `done()`, to // avoid destruction races. iterator->Unref(); diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 8a2b2639a783e4..cd72269859044e 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -107,9 +107,7 @@ class IteratorGetNextOp : public AsyncOpKernel { public: explicit IteratorGetNextOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_(ctx->env(), - strings::StrCat("iterator_get_next_thread_", - SanitizeThreadSuffix(name()))) {} + background_worker_(ctx->env(), "tf_data_iterator_get_next") {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 2a9a357d789477..31851925124330 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset_utils.h" @@ -241,7 +242,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { cond_var_(std::make_shared()), num_parallel_calls_(std::make_shared( params.dataset->num_parallel_calls_, mu_, cond_var_)), - map_func_(std::move(map_func)) {} + map_func_(std::move(map_func)) { + std::vector components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } ~Iterator() override { mutex_lock l(*mu_); @@ -256,13 +261,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); - AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); if (num_parallel_calls_->value == kAutoTune) { - num_parallel_calls_->value = 1; - AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, - port::NumSchedulableCPUs()); - } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); + num_parallel_calls_->value = ctx->runner_threadpool_size(); + num_parallel_calls_->tunable = true; } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); @@ -290,6 +291,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncKnownRatioNode( + std::move(args), dataset()->batch_size_, + {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, + /*max=*/ctx->runner_threadpool_size())}); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(*mu_); // Wait for all in-flight calls to complete. @@ -363,11 +372,18 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; - void CallCompleted(const std::shared_ptr& result) + void CallCompleted(const std::shared_ptr& ctx, + const std::shared_ptr& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); num_calls_--; result->num_calls--; + const auto& stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::active_parallel_calls"), + static_cast(num_calls_)); + } cond_var_->notify_all(); } @@ -387,7 +403,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return_early = result->end_of_input || !result->status.ok(); } if (return_early) { - CallCompleted(result); + CallCompleted(ctx, result); return; } @@ -429,7 +445,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { result->num_elements++; } } - CallCompleted(result); + CallCompleted(ctx, result); }; // Apply the map function on `input_element`, storing the result in @@ -464,7 +480,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { if (!runner_thread_) { auto ctx_copy = std::make_shared(*ctx); runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", + {}, "tf_data_map_and_batch", std::bind(&Iterator::RunnerThread, this, ctx_copy))); } } @@ -574,7 +590,19 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { num_calls_++; } } - + const std::shared_ptr& stats_aggregator = + ctx->stats_aggregator(); + if (stats_aggregator) { + mutex_lock l(*mu_); + // TODO(shivaniagrawal): add `parallel_calls_utilization` in the + // monitoring code or as histogram at fixed time intervals. + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::active_parallel_calls"), + static_cast(num_calls_)); + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::num_parallel_calls"), + static_cast(num_parallel_calls_->value)); + } for (const auto& call : new_calls) { CallFunction(ctx, call.first, call.second); } @@ -722,6 +750,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::deque> batch_results_ GUARDED_BY(*mu_); std::unique_ptr runner_thread_ GUARDED_BY(*mu_); bool cancelled_ GUARDED_BY(*mu_) = false; + string prefix_end_; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 9a92cd21822a91..d64114e70e531c 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -206,6 +206,12 @@ class MapDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/matching_files_dataset_op.cc index 09517ac264a969..d36b9e7e786774 100644 --- a/tensorflow/core/kernels/data/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/matching_files_dataset_op.cc @@ -182,6 +182,11 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 39356e9a591502..a0d8a1619c654a 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -86,9 +86,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - IteratorContext ctx_with_model(CreateParams(ctx)); - return dataset()->input_->MakeIterator(&ctx_with_model, prefix(), - &input_impl_); + IteratorContext::Params params(ctx); + params.model = model_; + return dataset()->input_->MakeIterator( + IteratorContext(std::move(params)), prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, @@ -96,12 +97,19 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx)); - IteratorContext ctx_with_model(CreateParams(ctx)); - return input_impl_->GetNext(&ctx_with_model, out_tensors, - end_of_sequence); + IteratorContext::Params params(ctx); + params.model = model_; + return input_impl_->GetNext(IteratorContext(std::move(params)), + out_tensors, end_of_sequence); } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); @@ -115,19 +123,13 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - IteratorContext::Params CreateParams(IteratorContext* ctx) { - IteratorContext::Params params = ctx->params(); - params.model = model_; - return params; - } - private: Status EnsureOptimizeThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!optimize_thread_) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); optimize_thread_.reset(ctx->env()->StartThread( - {}, "optimize_thread", + {}, "tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); })); } return Status::OK(); diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index d909b9e9d374a4..5268007e3d9528 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -86,12 +86,18 @@ class MultiDeviceIterator : public ResourceBase { void GetNextFromShard(IteratorContext* ctx, int shard_num, int64 incarnation_id, MultiDeviceIteratorCallback callback) { - if (lib_ != nullptr) { - ctx->set_lib(lib_); + if (ctx->lib() == lib_) { + tf_shared_lock l(mu_); + multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, + std::move(callback)); + } else { + IteratorContext::Params params(ctx); + params.lib = lib_; + IteratorContext iter_ctx(std::move(params)); + tf_shared_lock l(mu_); + multi_device_buffer_->GetNextFromShard( + &iter_ctx, shard_num, incarnation_id, std::move(callback)); } - tf_shared_lock l(mu_); - multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, - std::move(callback)); } const DataTypeVector& output_types() const { return output_types_; } @@ -200,7 +206,7 @@ class MultiDeviceIterator : public ResourceBase { EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!background_thread_) { background_thread_.reset(ctx->env()->StartThread( - {}, "multi_device_iterator_background_thread", + {}, "tf_data_multi_device_iterator", std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, this, new IteratorContext(*ctx)))); } @@ -455,8 +461,9 @@ class MultiDeviceIteratorInitOp : public OpKernel { core::ScopedUnref unref(resource); std::unique_ptr iterator; - IteratorContext iter_ctx(ctx); - iter_ctx.set_lib(resource->lib()); + IteratorContext::Params params(ctx); + params.lib = resource->lib(); + IteratorContext iter_ctx(std::move(params)); OP_REQUIRES_OK( ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); int64 incarnation_id; @@ -478,11 +485,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { public: explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - thread_pool_(new thread::ThreadPool( - ctx->env(), ThreadOptions(), - strings::StrCat("multi_device_iterator_get_next_thread_", - SanitizeThreadSuffix(name())), - 1 /* num_threads */, false /* low_latency_hint */)) {} + background_worker_(ctx->env(), + "tf_data_multi_device_iterator_get_next") {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { const Tensor* tensor_shard_num; @@ -497,18 +501,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { MultiDeviceIterator* iterator; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); - thread_pool_->Schedule(std::bind( + background_worker_.Schedule(std::bind( [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.function_library = iterator->function_library(); - DeviceBase* device = ctx->function_library()->device(); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - IteratorContext iter_ctx(std::move(params)); - MultiDeviceIteratorCallback callback = std::bind( [ctx](const HostBufferElement& elem, DoneCallback done) { // iterator->Unref(); @@ -526,6 +520,9 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { }, std::placeholders::_1, std::move(done)); + IteratorContext::Params params(ctx); + params.function_library = iterator->function_library(); + IteratorContext iter_ctx(std::move(params)); iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, callback); iterator->Unref(); @@ -534,7 +531,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { } private: - std::unique_ptr thread_pool_; + BackgroundWorker background_worker_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 47499633a2561f..726220e06bff22 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -94,9 +94,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { Node* input_node = nullptr; SerializationContext::Params params; std::vector> input_list; - params.allow_stateful_functions = true; params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); params.input_list = &input_list; + params.optimization_only = true; SerializationContext serialization_ctx(params); TF_RETURN_IF_ERROR( db.AddInputDataset(&serialization_ctx, input_, &input_node)); @@ -164,22 +164,28 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - IteratorContext::Params params = ctx->params(); + IteratorContext::Params params(ctx); params.lib = dataset()->lib_; return dataset()->optimized_input_->MakeIterator( - IteratorContext(params), prefix(), &input_impl_); + IteratorContext(std::move(params)), prefix(), &input_impl_); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { - IteratorContext::Params params = ctx->params(); + IteratorContext::Params params(ctx); params.lib = dataset()->lib_; - return input_impl_->GetNext(IteratorContext(params), out_tensors, - end_of_sequence); + return input_impl_->GetNext(IteratorContext(std::move(params)), + out_tensors, end_of_sequence); } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index b1515203b359eb..594a9ce7ec2d3a 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -207,7 +207,6 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -338,6 +337,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + dataset()->batch_size_); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 9c7bc5482cc55d..985e197a9934ec 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset_utils.h" @@ -246,7 +247,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -360,6 +360,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncInterleaveManyNode(std::move(args), + /*parameters=*/{}); + } + Status SaveInternal(IteratorStateWriter* writer) override { // The order of locking is important here to avoid deadlock. mutex_lock l(mu_); @@ -483,7 +489,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { for (size_t i = 0; i < dataset()->num_threads(); ++i) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( - {}, "worker_thread", + {}, strings::StrCat("tf_data_parallel_interleave_worker_", i), [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } @@ -582,7 +588,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { workers_[i].SetInputs(s, std::move(args)); std::shared_ptr new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( - {}, "worker_thread", + {}, strings::StrCat("tf_data_parallel_interleave_worker_", i), [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); @@ -1157,14 +1163,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - if (sloppy_) { - return MakeUnique( - ParallelInterleaveIteratorBase::Params{ - this, strings::StrCat(prefix, "::ParallelInterleaveV2")}); - } - return MakeUnique( - ParallelInterleaveIteratorBase::Params{ - this, strings::StrCat(prefix, "::ParallelInterleaveV2")}); + return MakeUnique( + ParallelInterleaveIterator::Params{ + this, strings::StrCat(prefix, "::ParallelInterleaveV2")}, + sloppy_); } const DataTypeVector& output_dtypes() const override { @@ -1225,23 +1227,29 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } private: - class ParallelInterleaveIteratorBase : public DatasetIterator { + class ParallelInterleaveIterator : public DatasetIterator { public: - explicit ParallelInterleaveIteratorBase(const Params& params) + explicit ParallelInterleaveIterator(const Params& params, bool sloppy) : DatasetIterator(params), mu_(std::make_shared()), cond_var_(std::make_shared()), num_parallel_calls_(std::make_shared( params.dataset->num_parallel_calls_, mu_, cond_var_)), + sloppy_(sloppy), args_list_(params.dataset->cycle_length_), current_elements_(params.dataset->cycle_length_), element_in_use_(params.dataset->cycle_length_, false), thread_pool_(new thread::ThreadPool( - Env::Default(), ThreadOptions(), "parallel_interleave", + Env::Default(), ThreadOptions(), + "tf_data_parallel_interleave_worker_pool", dataset()->cycle_length_ /* num_threads */, - false /* low_latency_hint */)) {} + false /* low_latency_hint */)) { + std::vector components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } - ~ParallelInterleaveIteratorBase() override { + ~ParallelInterleaveIterator() override { mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; @@ -1255,13 +1263,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == kAutoTune) { - num_parallel_calls_->value = 1; - AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, - dataset()->cycle_length_); - } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); + num_parallel_calls_->value = dataset()->cycle_length_; + num_parallel_calls_->tunable = true; } - AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -1299,17 +1303,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } protected: - struct InvocationResult { - Notification notification; // used for coordination with the consumer - Status status; // the invocation status - std::vector return_values; // the invocation result values - bool skip; // if set the result should be skipped - }; - - // Used by the consumer to determine whether it needs to wait. Upon - // returning false, `result` will either be NULL if end of input has been - // reached or point to a result to consume. - virtual bool ShouldWait(std::shared_ptr* result) = 0; + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncInterleaveManyNode( + std::move(args), + {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, + /*max=*/dataset()->cycle_length_)}); + } Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(*mu_); @@ -1397,12 +1397,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { return Status::OK(); } + private: + struct InvocationResult { + Notification notification; // used for coordination with the consumer + Status status; // the invocation status + std::vector return_values; // the invocation result values + bool skip; // if set the result should be skipped + }; + void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", + {}, "tf_data_parallel_interleave_runner", [this, new_ctx]() { RunnerThread(new_ctx); })); } } @@ -1445,6 +1453,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { mutex_lock l(*mu_); element_in_use_[cycle_index] = false; num_calls_--; + const auto& stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::active_parallel_calls"), + static_cast(num_calls_)); + } if (end_of_input) { args_list_[cycle_index].clear(); num_open_--; @@ -1522,15 +1536,54 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { num_calls_++; element_in_use_[cycle_index_] = true; thread_pool_->Schedule( - std::bind(&ParallelInterleaveIteratorBase::FetchOutputs, this, + std::bind(&ParallelInterleaveIterator::FetchOutputs, this, ctx, cycle_index_, std::move(results))); } cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; } + const auto& stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + // TODO(shivaniagrawal): add `parallel_calls_utilization` in the + // monitoring code or as histogram at fixed time intervals. + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::active_parallel_calls"), + static_cast(num_calls_)); + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::num_parallel_calls"), + static_cast(num_parallel_calls_->value)); + } cond_var_->notify_all(); } } + // Determines whether the caller needs to wait for a result. Upon + // returning false, `result` will either be NULL if end of input has been + // reached or point to the result. + bool ShouldWait(std::shared_ptr* result) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + if (sloppy_) { + for (auto it = invocation_results_.begin(); + it != invocation_results_.end(); ++it) { + if ((*it)->notification.HasBeenNotified()) { + std::swap(*result, *it); + invocation_results_.erase(it); + cond_var_->notify_all(); + return false; + } + } + return !invocation_results_.empty() || + (!end_of_input_ || num_open_ > 0); + } else { + if (!invocation_results_.empty()) { + std::swap(*result, invocation_results_.front()); + invocation_results_.pop_front(); + cond_var_->notify_all(); + return false; + } + return (!end_of_input_ || num_open_ > 0); + } + } + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { @@ -1631,6 +1684,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // Identifies the maximum number of parallel calls. const std::shared_ptr num_parallel_calls_; + // Determines whether outputs can be produced in non-deterministic order. + const bool sloppy_; + // Iterator for input elements. std::unique_ptr input_impl_ GUARDED_BY(*mu_); @@ -1665,45 +1721,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // Identifies whether background activity should be cancelled. bool cancelled_ GUARDED_BY(*mu_) = false; - }; - - class DeterministicParallelInterleave - : public ParallelInterleaveIteratorBase { - public: - using ParallelInterleaveIteratorBase::ParallelInterleaveIteratorBase; - - protected: - bool ShouldWait(std::shared_ptr* result) override - EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - if (!invocation_results_.empty()) { - std::swap(*result, invocation_results_.front()); - invocation_results_.pop_front(); - cond_var_->notify_all(); - return false; - } - return (!end_of_input_ || num_open_ > 0); - } - }; - - class SloppyParallelInterleave : public ParallelInterleaveIteratorBase { - public: - using ParallelInterleaveIteratorBase::ParallelInterleaveIteratorBase; - - protected: - bool ShouldWait(std::shared_ptr* result) override - EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - for (auto it = invocation_results_.begin(); - it != invocation_results_.end(); ++it) { - if ((*it)->notification.HasBeenNotified()) { - std::swap(*result, *it); - invocation_results_.erase(it); - cond_var_->notify_all(); - return false; - } - } - return !invocation_results_.empty() || - (!end_of_input_ || num_open_ > 0); - } + string prefix_end_; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 39809f5e9a5ba6..ec1c92384304d0 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/util/ptr_util.h" @@ -29,13 +30,13 @@ namespace tensorflow { namespace data { namespace { -class ParallelMapIteratorBase : public DatasetBaseIterator { +class ParallelMapIterator : public DatasetBaseIterator { public: - ParallelMapIteratorBase( - const typename DatasetBaseIterator::BaseParams& params, - const DatasetBase* input_dataset, - std::function init_func, - ParallelMapIteratorFunction map_func, int32 num_parallel_calls) + ParallelMapIterator(const typename DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, + std::function init_func, + ParallelMapIteratorFunction map_func, + int32 num_parallel_calls, bool sloppy) : DatasetBaseIterator(params), input_dataset_(input_dataset), init_func_(std::move(init_func)), @@ -43,9 +44,14 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { mu_(std::make_shared()), cond_var_(std::make_shared()), num_parallel_calls_(std::make_shared( - num_parallel_calls, mu_, cond_var_)) {} + num_parallel_calls, mu_, cond_var_)), + sloppy_(sloppy) { + std::vector components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } - ~ParallelMapIteratorBase() override { + ~ParallelMapIterator() override { mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; @@ -59,13 +65,8 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == kAutoTune) { - num_parallel_calls_->value = 1; - // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and - // use it here for the maximum. - AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, - port::NumSchedulableCPUs()); - } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); + num_parallel_calls_->value = ctx->runner_threadpool_size(); + num_parallel_calls_->tunable = true; } TF_RETURN_IF_ERROR( input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); @@ -94,16 +95,14 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { } protected: - struct InvocationResult { - Notification notification; - Status status; - std::vector return_values; - bool end_of_input; - }; - - // Used by the consumer to determine whether it needs to wait. Upon returning - // false, `result` will point to a result to consume. - virtual bool ShouldWait(std::shared_ptr* result) = 0; + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncKnownRatioNode( + std::move(args), + /*ratio=*/1, + {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, + /*max=*/ctx->runner_threadpool_size())}); + } Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(*mu_); @@ -127,10 +126,10 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { result.return_values[j])); } if (result.end_of_input) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name( - strings::StrCat("invocation_results[", i, "].end_of_input")), - "")); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(strings::StrCat("invocation_results[", + i, "].end_of_input")), + "")); } } return Status::OK(); @@ -141,8 +140,8 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64 invocation_results_size; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name("invocation_results.size"), &invocation_results_size)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("invocation_results.size"), + &invocation_results_size)); for (size_t i = 0; i < invocation_results_size; i++) { invocation_results_.push_back(std::make_shared()); auto& result = *invocation_results_.back(); @@ -150,15 +149,13 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { size_t num_return_values; { int64 size; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name(strings::StrCat( - "invocation_results[", i, "].size")), - &size)); + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + &size)); num_return_values = static_cast(size); if (num_return_values != size) { return errors::InvalidArgument(strings::StrCat( - full_name( - strings::StrCat("invocation_results[", i, "].size")), + full_name(strings::StrCat("invocation_results[", i, "].size")), ": ", size, " is not a valid value of type size_t.")); } } @@ -176,20 +173,35 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { return Status::OK(); } + private: + struct InvocationResult { + Notification notification; + Status status; + std::vector return_values; + bool end_of_input; + }; + void EnsureRunnerThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { auto ctx_copy = std::make_shared(*ctx); runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", - std::bind(&ParallelMapIteratorBase::RunnerThread, this, ctx_copy))); + {}, "tf_data_parallel_map", + std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); } } - void CallCompleted(const std::shared_ptr& result) + void CallCompleted(const std::shared_ptr& ctx, + const std::shared_ptr& result) LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); num_calls_--; + const auto& stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::active_parallel_calls"), + static_cast(num_calls_)); + } result->notification.Notify(); cond_var_->notify_all(); } @@ -202,13 +214,13 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { result->status = input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input); if (result->end_of_input || !result->status.ok()) { - CallCompleted(result); + CallCompleted(ctx, result); return; } - auto done = [this, result](Status status) { + auto done = [this, ctx, result](Status status) { result->status.Update(status); - CallCompleted(result); + CallCompleted(ctx, result); }; // Apply the map function on `input_element`, storing the result in @@ -218,8 +230,8 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { } Status ProcessResult(const std::shared_ptr& result, - std::vector* out_tensors, - bool* end_of_sequence) { + std::vector* out_tensors, bool* end_of_sequence) + LOCKS_EXCLUDED(*mu_) { if (!result->end_of_input && result->status.ok()) { *out_tensors = std::move(result->return_values); *end_of_sequence = false; @@ -235,7 +247,8 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { return result->status; } - void RunnerThread(const std::shared_ptr& ctx) { + void RunnerThread(const std::shared_ptr& ctx) + LOCKS_EXCLUDED(*mu_) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); std::vector> new_calls; @@ -261,6 +274,17 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { new_calls.push_back(invocation_results_.back()); num_calls_++; } + const auto& stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + // TODO(shivaniagrawal): add `parallel_calls_utilization` in the + // monitoring code or as histogram at fixed time intervals. + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::active_parallel_calls"), + static_cast(num_calls_)); + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::num_parallel_calls"), + static_cast(num_parallel_calls_->value)); + } cond_var_->notify_all(); } for (const auto& call : new_calls) { @@ -270,6 +294,30 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { } } + // Determines whether the caller needs to wait for a result. Upon returning + // false, `result` will point to the result. + bool ShouldWait(std::shared_ptr* result) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + if (sloppy_) { + for (auto it = invocation_results_.begin(); + it != invocation_results_.end(); ++it) { + if ((*it)->notification.HasBeenNotified() && + (it == invocation_results_.begin() || !(*it)->end_of_input)) { + std::swap(*result, *it); + invocation_results_.erase(it); + cond_var_->notify_all(); + return false; + } + } + } else if (!invocation_results_.empty()) { + std::swap(*result, invocation_results_.front()); + invocation_results_.pop_front(); + cond_var_->notify_all(); + return false; + } + return true; + } + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { @@ -300,8 +348,7 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { } string CodeKey(size_t index) { - return full_name( - strings::StrCat("invocation_results[", index, "].code")); + return full_name(strings::StrCat("invocation_results[", index, "].code")); } string ErrorMessageKey(size_t index) { @@ -322,6 +369,8 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { const std::shared_ptr cond_var_; // Identifies the maximum number of parallel calls. const std::shared_ptr num_parallel_calls_; + // Determines whether outputs can be produced in non-deterministic order. + const bool sloppy_; // Counts the number of outstanding calls. int64 num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr input_impl_; @@ -330,56 +379,7 @@ class ParallelMapIteratorBase : public DatasetBaseIterator { GUARDED_BY(*mu_); std::unique_ptr runner_thread_ GUARDED_BY(*mu_); bool cancelled_ GUARDED_BY(*mu_) = false; -}; - -class DeterministicParallelMapIterator : public ParallelMapIteratorBase { - public: - DeterministicParallelMapIterator( - const typename DatasetBaseIterator::BaseParams& params, - const DatasetBase* input_dataset, - std::function init_func, - ParallelMapIteratorFunction map_func, int32 num_parallel_calls) - : ParallelMapIteratorBase(params, input_dataset, init_func, map_func, - num_parallel_calls) {} - - protected: - bool ShouldWait(std::shared_ptr* result) override - EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - if (!invocation_results_.empty()) { - std::swap(*result, invocation_results_.front()); - invocation_results_.pop_front(); - cond_var_->notify_all(); - return false; - } - return true; - } -}; - -class SloppyParallelMapIterator : public ParallelMapIteratorBase { - public: - SloppyParallelMapIterator( - const typename DatasetBaseIterator::BaseParams& params, - const DatasetBase* input_dataset, - std::function init_func, - ParallelMapIteratorFunction map_func, int32 num_parallel_calls) - : ParallelMapIteratorBase(params, input_dataset, init_func, map_func, - num_parallel_calls) {} - - protected: - bool ShouldWait(std::shared_ptr* result) override - EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - for (auto it = invocation_results_.begin(); it != invocation_results_.end(); - ++it) { - if ((*it)->notification.HasBeenNotified() && - (it == invocation_results_.begin() || !(*it)->end_of_input)) { - std::swap(*result, *it); - invocation_results_.erase(it); - cond_var_->notify_all(); - return false; - } - } - return true; - } + string prefix_end_; }; } // namespace @@ -390,14 +390,9 @@ std::unique_ptr NewParallelMapIterator( std::function init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls, bool sloppy) { - if (sloppy) { - return MakeUnique( - params, input_dataset, std::move(init_func), std::move(map_func), - num_parallel_calls); - } - return MakeUnique( + return MakeUnique( params, input_dataset, std::move(init_func), std::move(map_func), - num_parallel_calls); + num_parallel_calls, sloppy); } } // namespace data diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index aa14c9c84d1da2..608b39d5f50e11 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -142,11 +142,11 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { it->second = i++; } - *output = new Dataset( - ctx, input, std::move(dense_defaults), std::move(sparse_keys_), - std::move(dense_keys_), std::move(key_to_output_index), - std::move(config), num_parallel_calls, sparse_types_, dense_types_, - dense_shapes_, output_types_, output_shapes_, sloppy_); + *output = + new Dataset(ctx, input, dense_defaults, sparse_keys_, dense_keys_, + std::move(key_to_output_index), std::move(config), + num_parallel_calls, sparse_types_, dense_types_, + dense_shapes_, output_types_, output_shapes_, sloppy_); } private: diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index e1d42a9a6be1bf..960373b74f3c65 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -103,7 +103,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { - auto stats_aggregator = ctx->stats_aggregator(); + const auto& stats_aggregator = ctx->stats_aggregator(); { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); @@ -123,7 +123,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } if (!buffer_.empty()) { - return Consume(out_tensors, end_of_sequence, stats_aggregator); + return Consume(out_tensors, end_of_sequence, ctx); } if (prefetch_thread_finished_) { @@ -148,6 +148,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeAsyncKnownRatioNode(std::move(args), + /*ratio=*/1, + /*parameters=*/{}); + } + Status SaveInternal(IteratorStateWriter* writer) override { // Acquire both locks to ensure that the prefetch thread and // all GetNext threads are blocked. @@ -220,8 +227,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { }; Status Consume(std::vector* out_tensors, bool* end_of_sequence, - const std::shared_ptr& stats_aggregator) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const auto& stats_aggregator = ctx->stats_aggregator(); if (stats_aggregator) { stats_aggregator->AddToHistogram( strings::StrCat(prefix_end_, "::buffer_utilization"), @@ -258,7 +265,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { if (!prefetch_thread_) { std::shared_ptr new_ctx(new IteratorContext(*ctx)); prefetch_thread_.reset(ctx->env()->StartThread( - {}, "prefetch_thread", + {}, "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); })); } return Status::OK(); diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index 63c137e2dd3c8c..816405fea90ef5 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -108,6 +108,11 @@ class RandomDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index c9ab25c0084e4c..1ad5b007751895 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -114,6 +114,11 @@ class RangeDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_)); diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index b1acbfc5a16096..ea97cf5ffdc0b2 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -168,6 +168,11 @@ class TextLineDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), @@ -435,6 +440,11 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), @@ -622,6 +632,11 @@ class TFRecordDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 486d1802ba5734..cee14df69d07a1 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -97,6 +97,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { return Status::OK(); } @@ -139,6 +145,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); @@ -210,6 +222,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!first_call_) diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 57ab98c795e36f..d9182d15bed272 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -211,6 +211,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index acf95a43ce7a20..ad6960685e4284 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -164,6 +164,12 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); // Save state needed to restore the random number generators. @@ -400,6 +406,12 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { seed2) {} protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(dataset()->mu_); diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 2218c405206c41..8379383662a760 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -93,6 +93,12 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { return Status::OK(); } @@ -149,6 +155,12 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index 23ad4ab89f8668..e67c5272b6fa3e 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -223,6 +223,12 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + dataset()->window_shift_); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index d0061523ca394b..a002c605357381 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -152,6 +152,11 @@ class Dataset : public DatasetBase { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_)); diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 45c2408b14dcaf..f01ecf84afab05 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -147,6 +147,11 @@ class SqlDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (query_connection_initialized_) { diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index 830eedfabb6163..a21b3fc16b7a93 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -163,22 +163,22 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); StatsAggregatorResource* stats_aggregator_resource = dataset()->stats_aggregator_resource_; - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); + IteratorContext::Params params(ctx); params.stats_aggregator = std::shared_ptr( new StatsAggregatorWithTagAndPrefix( stats_aggregator_resource->stats_aggregator(), dataset()->tag_, dataset()->prefix_)); - params.lib = ctx->lib(); - params.function_library = ctx->function_library(); - params.allocator_getter = ctx->allocator_getter(); - IteratorContext set_stats_aggregator_ctx(params); - return input_impl_->GetNext(&set_stats_aggregator_ctx, out_tensors, - end_of_sequence); + IteratorContext iter_ctx(std::move(params)); + return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence); } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { return errors::Unimplemented(dataset()->DebugString(), " does not support checkpointing"); diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index b360099ba36033..da0039773cee1b 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -116,6 +116,12 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); @@ -220,6 +226,12 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 97122570049dc6..57c9b0d57f6812 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -94,6 +94,12 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { return Status::OK(); } @@ -136,6 +142,12 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index ac5035c4d3d284..c7d374f489740a 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -69,10 +69,10 @@ class TensorDatasetOp : public DatasetOpKernel { components.reserve(tensors_.size()); for (const Tensor& t : tensors_) { Node* node; - std::vector>* input_list = ctx->input_list(); - if (input_list) { + if (ctx->optimization_only()) { TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node)); - input_list->emplace_back(node->name(), t); + DCHECK_NE(ctx->input_list(), nullptr); + ctx->input_list()->emplace_back(node->name(), t); } else { TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); } @@ -107,6 +107,11 @@ class TensorDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (produced_) diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index 195da54bac2da3..7fd1c4c9e0488a 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -419,6 +419,11 @@ class PrependFromQueueAndPaddedBatchDataset : public DatasetBase { const DatasetBase* dataset_input() const { return dataset()->input_; } + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_); + } + Status SaveInternal(IteratorStateWriter* writer) override { return queue_->Save(this, writer); } diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 2cdf91374887d1..6291bfc110bafe 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -92,10 +92,10 @@ class TensorSliceDatasetOp : public DatasetOpKernel { components.reserve(tensors_.size()); for (const Tensor& t : tensors_) { Node* node; - std::vector>* input_list = ctx->input_list(); - if (input_list) { + if (ctx->optimization_only()) { TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node)); - input_list->emplace_back(node->name(), t); + DCHECK_NE(ctx->input_list(), nullptr); + ctx->input_list()->emplace_back(node->name(), t); } else { TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); } @@ -140,6 +140,11 @@ class TensorSliceDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeSourceNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index b3219a01c32969..b32ab8ba4faa7b 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -145,6 +145,20 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + // Unbatch assumes that all input components have the same leading + // dimension. If it is statically known for any component, we model the + // transformation using `KnownRatio`. Otherwise, we use `UnknownRatio`. + for (auto& shape : dataset()->input_->output_shapes()) { + if (shape.dims() > 0 && shape.dim_size(0) > 0) { + return model::MakeKnownRatioNode( + std::move(args), 1.0 / static_cast(shape.dim_size(0))); + } + } + return model::MakeUnknownRatioNode(std::move(args)); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 3c45fa52888504..2c68e1ee05b542 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -232,6 +232,12 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + dataset()->window_shift_); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 06d65681e87e59..66e759a135591c 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -29,10 +29,7 @@ class ToTFRecordOp : public AsyncOpKernel { public: explicit ToTFRecordOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - thread_pool_(new thread::ThreadPool( - ctx->env(), ThreadOptions(), - strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())), - 1 /* num_threads */, false /* low_latency_hint */)) {} + background_worker_(ctx->env(), "tf_data_to_tf_record") {} template Status ParseScalarArgument(OpKernelContext* ctx, @@ -47,10 +44,9 @@ class ToTFRecordOp : public AsyncOpKernel { } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - // The call to `iterator->GetNext()` may block and depend on an - // inter-op thread pool thread, so we issue the call from the - // owned thread pool. - thread_pool_->Schedule([this, ctx, done]() { + // The call to `iterator->GetNext()` may block and depend on an inter-op + // thread pool thread, so we issue the call using a background thread. + background_worker_.Schedule([this, ctx, done]() { string filename; OP_REQUIRES_OK_ASYNC( ctx, ParseScalarArgument(ctx, "filename", &filename), done); @@ -97,7 +93,7 @@ class ToTFRecordOp : public AsyncOpKernel { } private: - std::unique_ptr thread_pool_; + BackgroundWorker background_worker_; }; REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index f505ed542a5571..6e94d77867168d 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -136,6 +136,14 @@ class ZipDatasetOp : public DatasetOpKernel { } protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + // NOTE: Although this dataset may have multiple inputs, it always + // consumes one element per input to produce an output. + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impls_.empty()) { diff --git a/tensorflow/core/kernels/eigen_activations_test.cc b/tensorflow/core/kernels/eigen_activations_test.cc index 34952f5abb8526..195525b02f9c3a 100644 --- a/tensorflow/core/kernels/eigen_activations_test.cc +++ b/tensorflow/core/kernels/eigen_activations_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/eigen_activations.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" namespace Eigen { diff --git a/tensorflow/core/kernels/eigen_attention_test.cc b/tensorflow/core/kernels/eigen_attention_test.cc index 08f61877182cce..8886dba49613b8 100644 --- a/tensorflow/core/kernels/eigen_attention_test.cc +++ b/tensorflow/core/kernels/eigen_attention_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/eigen_attention.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" namespace Eigen { diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc index 673ec1458b8fb7..e5500ba7dadc5b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h index 6a9a2accd8d807..a98850cf4b31fa 100644 --- a/tensorflow/core/kernels/eigen_cuboid_convolution.h +++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h @@ -948,7 +948,9 @@ class TensorContractionSubMapper< } private: - const ParentMapper& m_base_mapper; + const ParentMapper m_base_mapper; // Keeping a copy instead of a reference + // performs better in benchmarks. + Index m_depth_offset; // First row in the input matrix Index m_col_offset; // First col in the input matrix diff --git a/tensorflow/core/kernels/eigen_mkldnn.h b/tensorflow/core/kernels/eigen_mkldnn.h new file mode 100644 index 00000000000000..5235431f5f36e0 --- /dev/null +++ b/tensorflow/core/kernels/eigen_mkldnn.h @@ -0,0 +1,123 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_ + +// Support for Mkldnn sgemm kernel in Eigen/Tensor contractions: +// +// 1. Prepare packed Lhs/Rhs blocks from tensor expressions using +// DataMapper (see TensorContractionInputMapper). +// 2. Invoke gemm kernel with packed blocks (replacement for default +// gebp_kernel). + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/intel_mkl_dnn/include/mkldnn.h" + +namespace Eigen { +namespace internal { + +template +struct mkldnn_gemm_pack; + +// mkl_gemm_pack for ColMajor storage order. +template +struct mkldnn_gemm_pack { + typedef typename internal::packet_traits::type Packet; + typedef typename DataMapper::LinearMapper LinearMapper; + + enum { PacketSize = internal::packet_traits::size }; + + EIGEN_DONT_INLINE + void operator()(Scalar *block, const DataMapper &data_mapper, IndexType rows, + IndexType cols) { + const IndexType unrolled_rows = + (rows / (4 * PacketSize)) * (4 * PacketSize); + const IndexType vectorized_rows = (rows / PacketSize) * PacketSize; + + for (IndexType col = 0; col < cols; ++col) { + LinearMapper lm = data_mapper.getLinearMapper(0, col); + + // Give compiler a strong possibility to unroll the loop. + for (IndexType i = 0; i < unrolled_rows; i += 4 * PacketSize) { + for (IndexType j = 0; j < 4; ++j) { + const Packet p = lm.loadPacket(i + j * PacketSize); + internal::pstoreu(block + j * PacketSize, p); + } + block += 4 * PacketSize; + } + + // Process remaining rows with packets. + for (IndexType i = unrolled_rows; i < vectorized_rows; i += PacketSize) { + const Packet p = lm.loadPacket(i); + internal::pstoreu(block, p); + block += PacketSize; + } + + // Finalize with coefficients. + for (IndexType i = vectorized_rows; i < rows; ++i) { + *block = lm(i); + ++block; + } + } + } +}; + +template +struct mkldnn_gemm_kernel; + +// mkldnn_gemm_kernel for floats defined as a thin layer on top of mkldnn_sgemm. +template +struct mkldnn_gemm_kernel { + EIGEN_DONT_INLINE + void operator()(const OutputMapper &output, const float *blockA, + const float *blockB, const IndexType rows, + const IndexType depth, const IndexType cols, float alpha) { + static const int max_index = (std::numeric_limits::max)(); + + eigen_assert(max_index >= rows); + eigen_assert(max_index >= cols); + eigen_assert(max_index >= depth); + eigen_assert(max_index >= output.stride()); + + const int m = static_cast(rows); + const int n = static_cast(cols); + const int k = static_cast(depth); + + const char transposeA = ConjugateLhs ? 'Y' : 'N'; + const char transposeB = ConjugateRhs ? 'Y' : 'N'; + + const int ldA = ConjugateLhs ? k : m; + const int ldB = ConjugateRhs ? n : k; + const int ldC = static_cast(output.stride()); + + const float beta = 1.0; + + mkldnn_status_t st = mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k, + &alpha, blockA, &ldA, blockB, &ldB, &beta, + const_cast(output.data()), &ldC); + eigen_assert(st == 0); + } +}; + +} // namespace internal +} // namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_MKLDNN_H_ diff --git a/tensorflow/core/kernels/eigen_mkldnn_test.cc b/tensorflow/core/kernels/eigen_mkldnn_test.cc new file mode 100644 index 00000000000000..051ab28f792b44 --- /dev/null +++ b/tensorflow/core/kernels/eigen_mkldnn_test.cc @@ -0,0 +1,148 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/eigen_mkldnn.h" +#include "tensorflow/core/platform/test.h" + +namespace Eigen { +namespace internal { + +namespace { +template +Eigen::array RandomDims(int min_dim = 1, int max_dim = 20) { + Eigen::array dims; + for (int i = 0; i < NumDims; ++i) { + dims[i] = internal::random(min_dim, max_dim); + } + return dims; +} +} // namespace + +using Scalar = float; +using Index = Eigen::Index; + +TEST(EigenMkldnnTest, MkldnnPack) { + // Packing with mkldnn_gemm_pack is the same as taking a slice of 2 + // dimensional Tensor. + + // Mkldnn pack and gemm are used only in Tensor contractions, and it's + // guaranteed that Tensors will have ColMajor layout. + static const int Options = ColMajor; + + using DataMapper = blas_data_mapper; + using MkldnnGemmPack = mkldnn_gemm_pack; + using Tensor2d = Tensor; + + Eigen::array dims = RandomDims(1, 500); + + // Create a tensor initialized with random data. + Tensor2d src(dims); + src.setRandom(); + + // Pick a random slice of src tensor. + Eigen::array slice_start = RandomDims(0, 250); + Eigen::array slice_size = RandomDims(100, 500); + + // Make sure that slice start + size do not overflow tensor dims. + for (int i = 0; i < 2; ++i) { + slice_start[i] = numext::mini(dims[i] - 1, slice_start[i]); + slice_size[i] = numext::mini(slice_size[i], dims[i] - slice_start[i]); + } + + // Prepare tensors for packing and slicing results. + Tensor2d pack_dst(slice_size[0], slice_size[1]); + Tensor2d slice_dst(slice_size[0], slice_size[1]); + + // Pack memory using mkldnn_gemm_pack. + DataMapper data_mapper(src.data(), dims[0]); + MkldnnGemmPack gemm_pack; + gemm_pack(pack_dst.data(), + data_mapper.getSubMapper(slice_start[0], slice_start[1]), + slice_size[0], slice_size[1]); + + // Slice the source tensor. + slice_dst = src.slice(slice_start, slice_size); + + // Verify that dst tensors are equal. + EXPECT_EQ(pack_dst.dimensions().TotalSize(), + slice_dst.dimensions().TotalSize()); + for (size_t i = 0; i < pack_dst.dimensions().TotalSize(); ++i) { + Scalar packed = pack_dst.coeff(i); + Scalar sliced = slice_dst.coeff(i); + EXPECT_EQ(packed, sliced); + } +} + +TEST(EigenMkldnnTest, MkldnnGemm) { + // Mkldnn pack and gemm are used only in Tensor contractions, and it's + // guaranteed that Tensors will have ColMajor layout. + static const int Options = ColMajor; + + using Tensor2d = Tensor; + + int m = internal::random(1, 100); + int n = internal::random(1, 100); + int k = internal::random(1, 100); + + Tensor2d lhs(m, k); + lhs.setRandom(); + + Tensor2d rhs(k, n); + rhs.setRandom(); + + // Compute matmul with mkldnn gemm kernel. + using OutputMapper = blas_data_mapper; + using MkldnnGemmKernel = + mkldnn_gemm_kernel; + + Tensor2d mkldnn_result(m, n); + mkldnn_result.setZero(); + OutputMapper output_mapper(mkldnn_result.data(), m); + + MkldnnGemmKernel gemm_kernel; + gemm_kernel(output_mapper, lhs.data(), rhs.data(), m, k, n, /*alpha=*/1.0); + + // Compute matmul with Eigen::Matrix. + using Matrix = Eigen::Matrix; + using MatrixMap = Map>; + + MatrixMap lhs_mat(lhs.data(), m, k); + MatrixMap rhs_mat(rhs.data(), k, n); + + Matrix matmul_result(m, n); + matmul_result.setZero(); + matmul_result = lhs_mat * rhs_mat; + + // Verify that results are equal. + for (Index i = 0; i < m * n; ++i) { + Scalar gemm = mkldnn_result(i); + Scalar matmul = matmul_result(i % m, i / m); + + Scalar delta = std::abs(gemm - matmul); + + // NOTE(rmlarsen): Compute proper forward error bound. + Scalar sum = Scalar(0.0); + for (int k1 = 0; k1 < k; ++k1) { + sum += std::abs(lhs_mat(i % m, k1) * rhs_mat(k1, i / m)); + } + Scalar epsilon = std::numeric_limits::epsilon(); + Scalar upper_bound = Scalar(1.01) * epsilon * k * sum; + + EXPECT_LE(delta, upper_bound); + } +} + +} // namespace internal +} // namespace Eigen diff --git a/tensorflow/core/kernels/eigen_pooling_test.cc b/tensorflow/core/kernels/eigen_pooling_test.cc index 47b6665e680268..1fe9fd09dabbc1 100644 --- a/tensorflow/core/kernels/eigen_pooling_test.cc +++ b/tensorflow/core/kernels/eigen_pooling_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/eigen_pooling.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" namespace Eigen { diff --git a/tensorflow/core/kernels/eigen_softmax_test.cc b/tensorflow/core/kernels/eigen_softmax_test.cc index 7f985d71366487..30a1ccca052487 100644 --- a/tensorflow/core/kernels/eigen_softmax_test.cc +++ b/tensorflow/core/kernels/eigen_softmax_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/eigen_softmax.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" namespace Eigen { diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h index e926d73f87c0bb..a08c7064d5852a 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h @@ -726,9 +726,11 @@ class TensorContractionSubMapper< } private: - const ParentMapper& m_base_mapper; // that was a reference before - Index m_depth_offset; // First row in the input matrix - Index m_col_offset; // First col in the input matrix + const ParentMapper m_base_mapper; // Keeping a copy instead of a reference + // performs better in benchmarks. + + Index m_depth_offset; // First row in the input matrix + Index m_col_offset; // First col in the input matrix // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base // indices for the first element in a patch specified by col_offset diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc index 450b98785baf8a..b671421f5fde84 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/eigen_spatial_convolutions.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/eigen_cuboid_convolution.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace Eigen { @@ -1373,4 +1373,193 @@ TEST(EigenSpatialConvolutionsTest, SpatialConvContractionMapper) { EigenApprox(8.0f, direct(0, 1, 3, 0)); } +static void PackRhsHelper(int iters, + /* Input dimensions: */ + int input_batches, int input_cols, int input_rows, + int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_cols, int filter_rows) { + tensorflow::testing::UseRealTime(); + tensorflow::testing::StopTiming(); + + using Dimensions = Eigen::DSizes; + + // Default Eigen::Tensor layout is column major, so we configure dimensions + // starting from the inner most (channels aka depth in this case). + Dimensions input_dims(input_depth, input_rows, input_cols, input_batches); + + using Traits = typename Eigen::internal::gebp_traits; + static const int packet_size = Eigen::internal::packet_traits::size; + + // Reshape dimensions. + using NewDimension = Eigen::DSizes; + + // Contraction dimensions. + using nocontract_t = Eigen::array; + using contract_t = Eigen::array; + + // Input to the TensorImagePatchOp. It is the tensorflow TTypes::Tensor + // with ColMajor layout, instead of RowMajor. But that doesn't make any + // difference, because TensorContraction swaps LHS with RHS for row major + // inputs, and contraction mapper always works with column major data. + using ArgType = TensorMap, Eigen::Aligned>; + + using Evaluator = TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp>, + Eigen::DefaultDevice>; + + using InputMapper = Eigen::internal::TensorContractionInputMapper< + float, Index, Eigen::internal::Rhs, Evaluator, // + nocontract_t, contract_t, // + packet_size, // + /*inner_dim_contiguous*/ true, // + /*inner_dim_reordered*/ false, // + /*Alignment*/ 0>; + + using SubMapper = Eigen::internal::TensorContractionSubMapper< + float, Index, Eigen::internal::Rhs, Evaluator, // + nocontract_t, contract_t, // + packet_size, // + /*inner_dim_contiguous*/ true, // + /*inner_dim_reordered*/ false, // + /*Alignment*/ 0>; + + using PackRhsImpl = + Eigen::internal::gemm_pack_rhs; + + Eigen::DefaultDevice device; + + // Actual contract dimensions are not important. + const Eigen::Index not_important = -1234; + nocontract_t nocontract_dim = {not_important}; + contract_t contract_dim = {not_important}; + + // We use tensor of the same dimensions to store packed data. + Tensor packed(input_dims); + + // We generate multiple input tensors, around 512mb in total size to measure + // realistic workload when input data in not in L1-L3 cache. + size_t input_bytes = input_dims.TotalSize() * sizeof(float); + size_t mem_size_bytes = 1024 * 1024 * 512; + size_t num_inputs = + std::max(static_cast(1), mem_size_bytes / input_bytes); + + std::vector> inputs; + std::vector evaluators; + std::vector input_mappers; + + for (int i = 0; i < num_inputs; ++i) { + inputs.emplace_back(input_dims); + inputs[i].setRandom(); + + ArgType tensor_map(inputs[i].data(), input_dims); + + // 1. Extract image patches from input tensor. All strides are `1`. + const auto image_patch_op = TensorImagePatchOp( + tensor_map, // + filter_rows, filter_cols, // + /*row_strides=*/1, /*col_strides=*/1, // + /*in_row_strides=*/1, /*in_col_strides=*/1, // + /*row_inflate_strides=*/1, /*col_inflate_strides=*/1, // + Eigen::PADDING_SAME, /*padding_value=*/0.0); + + // 2. Reshape extracted patches into "virtual" 2d tensor. + // NOTE: for PADDING_SAME output {rows, cols} == input {rows, cols}. + NewDimension reshape_dims; + reshape_dims[0] = input_depth * filter_rows * filter_cols; // patch size + reshape_dims[1] = input_rows * input_cols * input_batches; // num_patches + + const auto reshape_op = + TensorReshapingOp( + image_patch_op, reshape_dims); + + evaluators.emplace_back(reshape_op, device); + + input_mappers.emplace_back(evaluators[i], nocontract_dim, nocontract_dim, + contract_dim, contract_dim); + } + + // We read properties of extracted image patches directly from evaluator. + const Index patch_depth = evaluators[0].impl().dimensions()[0]; + const Index patch_rows = evaluators[0].impl().dimensions()[1]; + const Index patch_cols = evaluators[0].impl().dimensions()[2]; + + // Number of patches is the same as the maximum column available through the + // InputMapper (SubMapper). + const Index num_patches = evaluators[0].impl().dimensions()[3]; + + // The size of a single patch, it's the same as the maximum depth available + // through the InputMapper (SubMapper). + const Index patch_size = patch_depth * patch_rows * patch_cols; + + PackRhsImpl pack_rhs; + + // This is the typical size of the rhs block used in Tensor contractions. + const Index default_depth = 320; // must be multiple of 8 + const Index default_cols = 280; + + const Index packed_total_size = input_dims.TotalSize(); + + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + int input_idx = + num_inputs == 1 ? 1 : internal::random(0, num_inputs - 1); + + // Depth offset must be a multiple of 8 (float packet size with AVX2). + Index depth_offset = (internal::random(0, patch_size - 10) / 8) * 8; + Index col_offset = internal::random(0, num_patches - 10); + + Index depth = std::min(default_depth, patch_size - depth_offset); + Index cols = std::min(default_cols, num_patches - col_offset); + + // Write packed data to random memory location to emulate cold caches. + Index packed_size = depth * cols; + Index packed_offset = + internal::random(0, packed_total_size - packed_size - 1); + + pack_rhs(packed.data() + packed_offset, + input_mappers[input_idx].getSubMapper(depth_offset, col_offset), + depth, cols); + } + tensorflow::testing::StopTiming(); + + std::ostringstream stringStream; + stringStream << "patch: depth=" << patch_depth << " rows=" << patch_rows + << " cols=" << patch_cols << " num_patches=" << num_patches + << " patch_size=" << patch_size << " num_inputs=" << num_inputs; + tensorflow::testing::SetLabel(stringStream.str()); +} + +#define BM_NAME(prefix, N, H, W, C, FC, FH, FW) \ + BM_##prefix##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW + +#define BM_PackRhs(N, H, W, C, FC, FH, FW) \ + static void BM_NAME(PackRhs, N, H, W, C, FC, FH, FW)(int iters) { \ + PackRhsHelper(iters, N, H, W, C, FC, FH, FW); \ + } \ + BENCHMARK(BM_NAME(PackRhs, N, H, W, C, FC, FH, FW)) + +// Number of input channel (input depth) it equal to the number of patch +// channels (patch depth). + +// NOTE: This is the most common case in Tensorflow models. +// Fast path: input channel dimension is the multiple of the packet size. +BM_PackRhs(/*batch*/ 32, // + /*image*/ 64, 64, // + /*channels*/ 32, // + /*num_filters*/ 64, // + /*filter*/ 5, 5); + +// Slow path: input channel dimension is not the multiple of the packet size. +BM_PackRhs(/*batch*/ 32, // + /*image*/ 64, 64, // + /*channels*/ 30, // + /*num_filters*/ 64, // + /*filter*/ 5, 5); + } // namespace Eigen diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index 812530e6a8b1a1..3d0c193d9fce60 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -46,20 +46,26 @@ TensorList::TensorList(const TensorList& other) void TensorList::Encode(VariantTensorData* data) const { data->set_type_name(TypeName()); - for (const Tensor& t : tensors) { - *data->add_tensors() = t; + std::vector invalid_indices; + for (size_t i = 0; i < tensors.size(); i++) { + if (tensors.at(i).dtype() != DT_INVALID) { + *data->add_tensors() = tensors.at(i); + } else { + invalid_indices.push_back(i); + } } string metadata; - core::PutVarint64(&metadata, static_cast(element_dtype)); - if (!element_shape.unknown_rank()) { - for (TensorShapeDim dim : element_shape) { - if (dim.size > 0) { - core::PutVarint64(&metadata, dim.size); - } else { - core::PutVarint64(&metadata, std::numeric_limits::max()); - } - } + // TODO(b/118838800): Add a proto for storing the metadata. + // Metadata format: + // + core::PutVarint64(&metadata, static_cast(invalid_indices.size())); + for (size_t i : invalid_indices) { + core::PutVarint64(&metadata, static_cast(i)); } + core::PutVarint64(&metadata, static_cast(element_dtype)); + TensorShapeProto element_shape_proto; + element_shape.AsProto(&element_shape_proto); + element_shape_proto.AppendToString(&metadata); data->set_metadata(metadata); } @@ -98,23 +104,45 @@ Status TensorListShape(const TensorList& t, TensorShape* s) { REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape); bool TensorList::Decode(const VariantTensorData& data) { - tensors = data.tensors(); + // TODO(srbs): Change the signature to Decode(VariantTensorData data) so + // that we do not have to copy each tensor individually below. This would + // require changing VariantTensorData::tensors() as well. string metadata; data.get_metadata(&metadata); uint64 scratch; StringPiece iter(metadata); + std::vector invalid_indices; core::GetVarint64(&iter, &scratch); - element_dtype = static_cast(scratch); - std::vector dims; - while (!iter.empty()) { + size_t num_invalid_tensors = static_cast(scratch); + invalid_indices.resize(num_invalid_tensors); + for (size_t i = 0; i < num_invalid_tensors; i++) { core::GetVarint64(&iter, &scratch); - if (scratch == std::numeric_limits::max()) { - dims.push_back(-1); + invalid_indices[i] = static_cast(scratch); + } + + size_t total_num_tensors = data.tensors().size() + num_invalid_tensors; + tensors.reserve(total_num_tensors); + std::vector::iterator invalid_indices_it = invalid_indices.begin(); + std::vector::const_iterator tensors_it = data.tensors().begin(); + for (size_t i = 0; i < total_num_tensors; i++) { + if (invalid_indices_it != invalid_indices.end() && + *invalid_indices_it == i) { + tensors.emplace_back(Tensor(DT_INVALID)); + invalid_indices_it++; + } else if (tensors_it != data.tensors().end()) { + tensors.emplace_back(*tensors_it); + tensors_it++; } else { - dims.push_back(scratch); + // VariantTensorData is corrupted. + return false; } } - element_shape = PartialTensorShape(dims); + + core::GetVarint64(&iter, &scratch); + element_dtype = static_cast(scratch); + TensorShapeProto element_shape_proto; + element_shape_proto.ParseFromString(string(iter.data(), iter.size())); + element_shape = PartialTensorShape(element_shape_proto); return true; } diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 38cadd10e1dd69..4ebe165937055a 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -580,13 +580,16 @@ struct MatMulFunctor { #if defined(INTEL_MKL) && defined(ENABLE_MKL) -// MKL does not support half, bfloat16 and int32 types for +// MKL supports float, double, complex64 and complex128 types for +// matrix-multiplication, and these kernels are registered in mkl_matmul_op.cc. +// MKL does not support half, bfloat16, int32 and int64 types for // matrix-multiplication, so register the kernel to use default Eigen based // implementations for these types. REGISTER_CPU defines two versions - Eigen // label and NO-LABEL TF_CALL_half(REGISTER_CPU); TF_CALL_bfloat16(REGISTER_CPU); TF_CALL_int32(REGISTER_CPU); +TF_CALL_int64(REGISTER_CPU); // Float is supported in both MKL DNN as well as in MKL ML // Registration for NO-LABEL version is in mkl_matmul_op.cc for types supported diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 8ad7ebb51f3c11..f0278caee6b952 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -13,9 +13,10 @@ limitations under the License. #ifdef INTEL_MKL #include -#include #include +#include +#include "mkldnn.hpp" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -27,15 +28,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -#ifndef INTEL_MKL_ML_ONLY -#include "mkldnn.hpp" - using mkldnn::concat; using mkldnn::stream; -#else -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" -#endif #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { @@ -63,95 +57,6 @@ class EigenConcatBaseOp : public OpKernel { // we need to have empty Compute because Compute is pure virtual function. void Compute(OpKernelContext* c) {} -#ifdef INTEL_MKL_ML_ONLY - - void Compute(OpKernelContext* c, const std::vector& values) { - const Tensor* concat_dim_tensor; - const char* axis_attribute_name = - AxisArgName == NAME_IS_AXIS - ? "axis" - : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : ""; - OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor)); - OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()), - errors::InvalidArgument( - axis_attribute_name, - " tensor should be a scalar integer, but got shape ", - concat_dim_tensor->shape().DebugString())); - const int32 concat_dim = - internal::SubtleMustCopy(concat_dim_tensor->scalar()()); - // Instead of accessing values from context, we use input to Compute. - const int N = values.size(); - const int input_dims = values[0].dims(); - const TensorShape& input_shape = values[0].shape(); - - int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(c, - (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), - errors::InvalidArgument( - "ConcatOp : Expected concatenating dimensions in the range " - "[", - -input_dims, ", ", input_dims, "), but got ", concat_dim)); - // Note that we reduce the concat of n-dimensional tensors into a two - // dimensional concat. Assuming the dimensions of any input/output - // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along - // the dimension indicated with size y0, we flatten it to {x, y}, where y = - // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1). - ConstMatrixVector inputs_flat; - inputs_flat.reserve(N); - int64 inputs_flat_dim0 = 1; - for (int d = 0; d < axis; ++d) { - inputs_flat_dim0 *= input_shape.dim_size(d); - } - int64 output_concat_dim = 0; - const bool input_is_scalar = IsLegacyScalar(input_shape); - for (int i = 0; i < N; ++i) { - const auto in = values[i]; - const bool in_is_scalar = IsLegacyScalar(in.shape()); - OP_REQUIRES( - c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), - errors::InvalidArgument( - "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, - "] = ", in.shape().DebugString())); - for (int j = 0; j < input_dims; ++j) { - if (j == axis) { - continue; - } - OP_REQUIRES( - c, in.dim_size(j) == input_shape.dim_size(j), - errors::InvalidArgument( - "ConcatOp : Dimensions of inputs should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, - "] = ", in.shape().DebugString())); - } - if (in.NumElements() > 0) { - int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; - inputs_flat.emplace_back(new typename TTypes::ConstMatrix( - in.shaped({inputs_flat_dim0, inputs_flat_dim1}))); - } - // TODO(irving): Remove check once !allow_legacy_scalars(). - output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1; - } - - TensorShape output_shape(input_shape); - // TODO(irving): Remove rank 0 case once !allow_legacy_scalars(). - if (output_shape.dims() == 0) { - output_shape.AddDim(output_concat_dim); - } else { - output_shape.set_dim(axis, output_concat_dim); - } - Tensor* output = nullptr; - OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); - if (output->NumElements() > 0) { - int64 output_dim1 = output->NumElements() / inputs_flat_dim0; - auto output_flat = output->shaped({inputs_flat_dim0, output_dim1}); - ConcatCPU(c->device(), inputs_flat, &output_flat); - } - } - -#else // MKL_DNN - void Compute(OpKernelContext* c, const std::vector& values, const TensorShapeList& input_shapes) { const Tensor* concat_dim_tensor; @@ -227,342 +132,7 @@ class EigenConcatBaseOp : public OpKernel { ConcatCPU(c->device(), inputs_flat, &output_flat); } } - -#endif -}; - -#ifdef INTEL_MKL_ML_ONLY - -// -------------------------------------------------------------------------- -// Mkl Concat Op -// -------------------------------------------------------------------------- - -template -class MklConcatOp : public OpKernel { - private: - TensorFormat data_format_; - EigenConcatBaseOp eigen_concat_op_; - - public: - typedef std::vector::ConstMatrix>> - ConstMatrixVector; - - explicit MklConcatOp(OpKernelConstruction* c) - : OpKernel(c), eigen_concat_op_(c) {} - - void Compute(OpKernelContext* context) override { - MklConcatOpContext mkl_context; - - // Get input tensors. - OpInputList input_tensors; - GetMklInputList(context, "values", &input_tensors); - const int N = input_tensors.size(); - // Get MKL shapes. - MklShapeList input_shapes(N); - GetMklShapeList(context, "values", &input_shapes); - - // If this is Concat, then concat_dim is 0th input. - // If this is ConcatV2, then axis is Nth input. - const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM - ? MklGetInput(context, 0) - : MklGetInput(context, N); - - // Sanity checks - OP_REQUIRES( - context, IsLegacyScalar(concat_dim_tensor.shape()), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor.shape().DebugString())); - int32 concat_dim = - internal::SubtleMustCopy(concat_dim_tensor.scalar()()); - - MklShape& inpshape0 = input_shapes[0]; - - // Check that all tensors are Mkl, if not we call Eigen version. - bool invoke_eigen = false; - bool is_concat_dim_channel = true; - if (!AreAllMklTensors(input_shapes)) { - invoke_eigen = true; - } - - // Check that total number of dimensions is 4, if not call Eigen. - if (!invoke_eigen) { - for (auto& s : input_shapes) { - if (s.GetDimension() != 4) { - invoke_eigen = true; - break; - } - } - } - - // check that concat_dim is channel, if not call Eigen version. - if (!invoke_eigen) { - for (auto& s : input_shapes) { - if (!s.IsMklChannelDim(concat_dim)) { - invoke_eigen = true; - is_concat_dim_channel = false; - break; - } - } - } - - if (invoke_eigen) { - VLOG(1) << "_MklConcatOp: Invoking Eigen version of Concat. Reason:" - << (!is_concat_dim_channel ? "Concat dimension is not channel" - : "Not all tensors are in Mkl layout"); - CallEigenVersion(context, input_tensors, input_shapes); - return; - } - - // For MKL format, the channel is dimension number 2. - // So if we are concating over channel and _all_ inputs are in MKL - // format, then we set concat_dim to 2. - // Since we have reached till here, it means we are concating - // over channel. - concat_dim = MklDims::C; - - // One more sanity check: check that ranks of all tensors match - // and that their shapes match except for concat_dim. - int i = 0; - for (auto& s : input_shapes) { - size_t exp_dims = inpshape0.GetDimension(); - OP_REQUIRES(context, s.GetDimension() == exp_dims, - errors::InvalidArgument( - "_MklConcatOp : Ranks of all input tensors should match:" - " input dimensions = ", - s.GetDimension(), " vs. expected rank = ", exp_dims)); - - for (int d = 0; d < exp_dims; ++d) { - if (d == concat_dim) { - continue; - } - - size_t exp_size = inpshape0.GetSizes()[d]; - OP_REQUIRES( - context, exp_size == s.GetSizes()[d], - errors::InvalidArgument("_MklConcatOp : Dimensions of inputs" - "should match: shape[0][", - d, "]= ", exp_size, " vs. shape[", i, "][", - d, "] = ", s.GetSizes()[d])); - } - ++i; - } - - // Use input MKL layout instead of creating new layouts. - int64 output_concat_dim_size = 0; - for (auto& s : input_shapes) { - output_concat_dim_size += - s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1; - } - mkl_context.MklCreateInputLayouts(context, input_shapes); - OP_REQUIRES_OK(context, context->status()); - - CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N, - &mkl_context.lt_inputs[0]), - E_SUCCESS); - - // Calculate output sizes and strides - TensorFormat data_format; - if (inpshape0.IsTensorInNHWCFormat()) { - data_format = FORMAT_NHWC; - } else { - OP_REQUIRES( - context, inpshape0.IsTensorInNCHWFormat(), - errors::InvalidArgument( - "_MklConcat only supports all inputs in NCHW or NHWC format ")); - data_format = FORMAT_NCHW; - } - - // Since all tensors are in Mkl layout, we copy sizes from input tensor. - mkl_context.out_sizes[MklDims::W] = inpshape0.GetSizes()[MklDims::W]; - mkl_context.out_sizes[MklDims::H] = inpshape0.GetSizes()[MklDims::H]; - mkl_context.out_sizes[MklDims::C] = output_concat_dim_size; - mkl_context.out_sizes[MklDims::N] = inpshape0.GetSizes()[MklDims::N]; - GetStridesFromSizes(data_format, mkl_context.out_strides, - mkl_context.out_sizes); - - // Set output Mkl shape. - int64 dim = 4; - MklShape mkl_output_mkl_shape; - mkl_output_mkl_shape.SetMklTensor(true); - mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_concat, dnnResourceDst); - mkl_output_mkl_shape.SetTfLayout(dim, mkl_context.out_sizes, - mkl_context.out_strides); - mkl_output_mkl_shape.SetTfDimOrder(dim, inpshape0.GetTfToMklDimMap()); - - TensorShape mkl_output_tf_shape; - mkl_output_tf_shape.AddDim(1); - mkl_output_tf_shape.AddDim( - dnnLayoutGetMemorySize_F32( - static_cast(mkl_output_mkl_shape.GetMklLayout())) / - sizeof(T)); - - Tensor* output = nullptr; - AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape, - mkl_output_mkl_shape); - - // Set destination resource. - mkl_context.concat_res[dnnResourceDst] = - const_cast(static_cast(output->flat().data())); - - mkl_context.mkl_tmp_tensors.resize(N); - mkl_context.MklPrepareConcatInputs(context, input_tensors); - OP_REQUIRES_OK(context, context->status()); - - // Execute primitive. - CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res), - E_SUCCESS); - - mkl_context.MklCleanup(); - OP_REQUIRES_OK(context, context->status()); - } - - private: - typedef struct { - TensorFormat data_format; - size_t out_sizes[4]; - size_t out_strides[4]; - dnnPrimitive_t prim_concat; - void* concat_res[dnnResourceNumber]; - std::vector lt_inputs; - std::vector mkl_tmp_tensors; - - // Create MKL dnnLayout_t objects for tensors coming into the layer - // We only support case where input tensors are all in Mkl layout. - void MklCreateInputLayouts(OpKernelContext* context, - MklShapeList& input_shapes) { - for (auto& is : input_shapes) { - CHECK_EQ(is.IsMklTensor(), true); - lt_inputs.push_back((dnnLayout_t)is.GetCurLayout()); - } - } - - void MklPrepareConcatInputs(OpKernelContext* context, - OpInputList& input_tensors) { - CHECK_EQ(lt_inputs.size(), mkl_tmp_tensors.size()); - - for (int i = 0; i < lt_inputs.size(); ++i) { - dnnPrimitive_t mkl_prim_convert_input; - dnnLayout_t mkl_lt_internal_input; - void* mkl_buf_convert_input = nullptr; - - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - &mkl_lt_internal_input, prim_concat, - (dnnResourceType_t)(dnnResourceMultipleSrc + i)), - E_SUCCESS); - - if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, - lt_inputs[i], mkl_lt_internal_input), - E_SUCCESS); - - AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input, - &mkl_buf_convert_input); - - CHECK_EQ(dnnConversionExecute_F32( - mkl_prim_convert_input, - const_cast(static_cast( - input_tensors[i].flat().data())), - mkl_buf_convert_input), - E_SUCCESS); - - concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input; - CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS); - } else { - concat_res[dnnResourceMultipleSrc + i] = const_cast( - static_cast(input_tensors[i].flat().data())); - } - - CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS); - } - } - - void MklCleanup() { - for (auto& lt : lt_inputs) { - lt = nullptr; - } - CHECK_EQ(dnnDelete_F32(prim_concat), E_SUCCESS); - } - } MklConcatOpContext; - - void CallEigenVersion(OpKernelContext* context, const OpInputList& values, - const MklShapeList& input_shapes) { - // Before calling Eigen version, we need to convert Mkl tensors to TF. - // First check that the number of input tensors and the number of Mkl - // shapes match. - CHECK_EQ(values.size(), input_shapes.size()); - - std::vector converted_values; - for (int i = 0; i < input_shapes.size(); i++) { - if (input_shapes[i].IsMklTensor()) { - // If input tensor is Mkl, then do the conversion. - Tensor tmp_tensor = - ConvertMklToTF(context, values[i], input_shapes[i]); - converted_values.push_back(tmp_tensor); - } else { - // If input tensor is TF already, then we do not need any conversion. - converted_values.push_back(values[i]); - } - } - - // Call Eigen concat. - eigen_concat_op_.Compute(context, converted_values); - - // Set dummy Mkl tensor as output Mkl tensor for this op. - MklShape mkl_tensor_mkl_shape; - mkl_tensor_mkl_shape.SetMklTensor(false); - mkl_tensor_mkl_shape.SetDimensions(4); - mkl_tensor_mkl_shape.SetTfDimOrder(4); // Dimensions - Tensor* mkl_tensor = nullptr; - TensorShape mkl_tensor_tf_shape; - mkl_tensor_tf_shape.AddDim( - SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension())); - int tf_output_index = 0; - // TODO(jktomer): replace this with OP_REQUIRES_OK and clean up this file - // to propagate the status up the call stack. - TF_CHECK_OK(context->allocate_output( - GetTensorMetaDataIndex(tf_output_index, context->num_outputs()), - mkl_tensor_tf_shape, &mkl_tensor)); - mkl_tensor_mkl_shape.SerializeMklShape( - mkl_tensor->flat().data(), - mkl_tensor->flat().size() * sizeof(uint8)); - } - - // overloading methods with input shapes as a list of TensorShape's - void CallEigenVersion(OpKernelContext* context, const OpInputList& values, - const TensorShapeList& input_shapes) { - CHECK_EQ(values.size(), input_shapes.size()); - - std::vector converted_values; - for (int i = 0; i < input_shapes.size(); i++) { - converted_values.push_back(values[i]); - } - - // Call Eigen concat. - eigen_concat_op_.Compute(context, converted_values); - - // Set dummy Mkl tensor as output Mkl tensor for this op. - MklShape mkl_tensor_mkl_shape; - mkl_tensor_mkl_shape.SetMklTensor(false); - mkl_tensor_mkl_shape.SetDimensions(4); - Tensor* mkl_tensor = nullptr; - TensorShape mkl_tensor_tf_shape; - mkl_tensor_tf_shape.AddDim( - SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension())); - int tf_output_index = 0; - // TODO(jktomer): replace this with OP_REQUIRES_OK and clean up this file - // to propagate the status up the call stack. - TF_CHECK_OK(context->allocate_output( - GetTensorMetaDataIndex(tf_output_index, context->num_outputs()), - mkl_tensor_tf_shape, &mkl_tensor)); - mkl_tensor_mkl_shape.SerializeMklShape( - mkl_tensor->flat().data(), - mkl_tensor->flat().size() * sizeof(uint8)); - } }; - -#else - // -------------------------------------------------------------------------- // Mkl Concat Op // -------------------------------------------------------------------------- @@ -609,8 +179,8 @@ class MklConcatOp : public OpKernel { bool invoke_eigen = false; bool are_all_mkl_inputs = true, are_all_tf_inputs = true; const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor() - ? mkl_input_shapes[0].GetTfShape() - : input_tensors[0].shape(); + ? mkl_input_shapes[0].GetTfShape() + : input_tensors[0].shape(); size_t expected_dims = expected_shape.dims(); if (concat_dim < 0) concat_dim = expected_dims + concat_dim; @@ -681,13 +251,12 @@ class MklConcatOp : public OpKernel { if (are_all_mkl_inputs) { mkl_common_format = FindMklCommonFormat(mkl_input_shapes, concat_dim, - &isMklReorderNeeded, &dst_concat_dim_size); + &isMklReorderNeeded, &dst_concat_dim_size); if (!isMklReorderNeeded) { // All MKL tensors have a same format. Reorder is not needed. for (int k = 0; k < N; k++) { - if (input_tensors[k].NumElements() == 0) - continue; + if (input_tensors[k].NumElements() == 0) continue; auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); @@ -698,16 +267,16 @@ class MklConcatOp : public OpKernel { // MKL tensors have different formats. // Reorder them to most common format. for (int k = 0; k < N; k++) { - if (input_tensors[k].NumElements() == 0) - continue; + if (input_tensors[k].NumElements() == 0) continue; auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); if (src_md.data.format != mkl_common_format) { - memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]); - src_md = memory::desc(src_dims, MklDnnType(), - mkl_common_format); + memory::dims src_dims(src_md.data.dims, + &src_md.data.dims[src_md.data.ndims]); + src_md = + memory::desc(src_dims, MklDnnType(), mkl_common_format); } srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine)); @@ -715,8 +284,7 @@ class MklConcatOp : public OpKernel { } } else { // All TF inputs for (int k = 0; k < N; k++) { - if (input_tensors[k].NumElements() == 0) - continue; + if (input_tensors[k].NumElements() == 0) continue; memory::dims src_dims = TFShapeToMklDnnDims(input_tensors[k].shape()); dst_concat_dim_size += src_dims[concat_dim]; @@ -744,8 +312,8 @@ class MklConcatOp : public OpKernel { dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format)); // Set the output format same as the most common format of inputs // to avoid layout conversions. - dst_md = memory::desc( - dst_dims_in_nchw, MklDnnType(), mkl_common_format); + dst_md = + memory::desc(dst_dims_in_nchw, MklDnnType(), mkl_common_format); } else { // All inputs are TF tensors. // Set the output format same as input format (nchw). @@ -774,9 +342,10 @@ class MklConcatOp : public OpKernel { // E.g., if we are concatinating over Channel (dimension 3 for NHWC), // then since MklDnn order is NCHW, concat_dim needs to be 1. if (are_all_mkl_inputs) - concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim); + concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim); - auto concat_pd = concat::primitive_desc(dst_md, concat_dim, srcs_pd); + auto concat_pd = concat::primitive_desc(concat_dim, srcs_pd); + auto dst_pd = concat_pd.dst_primitive_desc(); MklDnnShape dnn_shape_dst; TensorShape tf_shape_dst; @@ -865,27 +434,26 @@ class MklConcatOp : public OpKernel { // Return: // return the common MKL format. memory::format FindMklCommonFormat(const MklDnnShapeList& input_shapes, - int concat_dim, bool* is_reorder_needed, int64* concat_dim_size) { + int concat_dim, bool* is_reorder_needed, + int64* concat_dim_size) { *is_reorder_needed = false; *concat_dim_size = 0; std::unordered_map occurrence_map; - if (input_shapes.size() == 0) - return memory::format::any; + if (input_shapes.size() == 0) return memory::format::any; // Compute ocurrences of each format of all inputs. - for (int k=0; k ( - input_shapes[k].GetMklLayout().data.format); + int fmt = static_cast(input_shapes[k].GetMklLayout().data.format); occurrence_map[fmt] += 1; } if (occurrence_map.size() == 1) { - // this means that all inputs have a same format - // return it with is_reorder_needed set false. - return static_cast( - input_shapes[0].GetMklLayout().data.format); + // this means that all inputs have a same format + // return it with is_reorder_needed set false. + return static_cast( + input_shapes[0].GetMklLayout().data.format); } // Input tensors have different formats. Thus, reorder is needed. @@ -904,8 +472,6 @@ class MklConcatOp : public OpKernel { } }; -#endif - /* Use optimized concat for float type only */ #define REGISTER_MKL_CPU(type) \ REGISTER_KERNEL_BUILDER(Name("_MklConcat") \ diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index bb8254eaacf97f..f5644d0da4cee3 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -880,11 +880,11 @@ struct ReduceFunctor> { }; template -struct ReduceFunctor> { +struct ReduceFunctor> { template static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, - const Eigen::internal::MeanReducer& reducer) { + const functor::MeanReducer& reducer) { int divisor = 1; if (out.rank() == 0) divisor = in.size(); @@ -910,17 +910,17 @@ struct ReduceFunctor> { template static void FillIdentity(const GPUDevice& d, OUT_T out, - const Eigen::internal::MeanReducer& reducer) { + const functor::MeanReducer& reducer) { FillIdentityEigenImpl(d, To32Bit(out), reducer); } }; template <> -struct ReduceFunctor> { +struct ReduceFunctor> { template static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, - const Eigen::internal::MeanReducer& reducer) { + const functor::MeanReducer& reducer) { float divisor = 1.f; if (out.rank() == 0) divisor = in.size(); @@ -952,9 +952,8 @@ struct ReduceFunctor> { } template - static void FillIdentity( - const GPUDevice& d, OUT_T out, - const Eigen::internal::MeanReducer& reducer) { + static void FillIdentity(const GPUDevice& d, OUT_T out, + const functor::MeanReducer& reducer) { FillIdentityEigenImpl(d, To32Bit(out), reducer); } }; diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h index eb264e0e5a7363..2331599b72f46d 100644 --- a/tensorflow/core/kernels/reduction_ops.h +++ b/tensorflow/core/kernels/reduction_ops.h @@ -26,13 +26,35 @@ limitations under the License. namespace tensorflow { namespace functor { +// Dummy class used for template specialization for mean reduction, which is +// accomplished by SumReducer and on-the-fly division by the reduction factor. +template +struct MeanReducer { + Scalar initialize() const { return Scalar(0); } +}; + template -void ReduceEigenImpl(const Device& d, OUT_T out, IN_T in, - const ReductionAxes& reduction_axes, - const Reducer& reducer) { - out.device(d) = in.reduce(reduction_axes, reducer); -} +struct ReduceEigenImpl { + void operator()(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, const Reducer& reducer) { + out.device(d) = in.reduce(reduction_axes, reducer); + } +}; + +template +struct ReduceEigenImpl> { + void operator()(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::MeanReducer& reducer) { + static_assert(std::is_same::value, ""); + Eigen::internal::SumReducer sum_reducer; + out.device(d) = in.reduce(reduction_axes, sum_reducer) / + static_cast(in.size() / out.size()); + } +}; // For most reducers, the identity is Reducer::initialize() template @@ -46,12 +68,12 @@ struct Identity { // MeanReducer is a special case, since it doesn't technically have an identity. // Thus, ideally we'd return nan. However, mean is instantiated for integer // types as well, so we do the nan override only for floating point types. -#define FIX_MEAN_IDENTITY(T) \ - template <> \ - struct Identity> { \ - static T identity(const Eigen::internal::MeanReducer&) { \ - return Eigen::NumTraits::quiet_NaN(); \ - } \ +#define FIX_MEAN_IDENTITY(T) \ + template <> \ + struct Identity> { \ + static T identity(const functor::MeanReducer&) { \ + return Eigen::NumTraits::quiet_NaN(); \ + } \ }; FIX_MEAN_IDENTITY(Eigen::half) FIX_MEAN_IDENTITY(float) diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index d83e1c7d15d22f..c6c36ec29a782e 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -256,7 +256,8 @@ struct ReduceFunctorBase { const ReductionAxes& reduction_axes, const Reducer& reducer) { const Device& d = ctx->eigen_device(); - ReduceEigenImpl(d, out, in, reduction_axes, reducer); + ReduceEigenImpl reducer_impl; + reducer_impl(d, out, in, reduction_axes, reducer); } template diff --git a/tensorflow/core/kernels/reduction_ops_gpu_complex128.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_complex128.cu.cc index cb19c084e3984f..c44a40b3b38f5a 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_complex128.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_complex128.cu.cc @@ -52,7 +52,7 @@ typedef TTypes::Tensor::Index Index; DEFINE_IDENTITY(T, R) DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::SumReducer); -DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::MeanReducer); +DEFINE_FOR_TYPE_AND_R(complex128, functor::MeanReducer); DEFINE_FOR_TYPE_AND_R(complex128, Eigen::internal::ProdReducer); #undef DEFINE_FOR_TYPE_AND_R #undef DEFINE diff --git a/tensorflow/core/kernels/reduction_ops_gpu_complex64.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_complex64.cu.cc index fa550e594a5c59..1921130ac043d9 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_complex64.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_complex64.cu.cc @@ -52,7 +52,7 @@ typedef TTypes::Tensor::Index Index; DEFINE_IDENTITY(T, R) DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::SumReducer); -DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::MeanReducer); +DEFINE_FOR_TYPE_AND_R(complex64, functor::MeanReducer); DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::ProdReducer); #undef DEFINE_FOR_TYPE_AND_R #undef DEFINE diff --git a/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc index de46933f615869..119f726b929bd9 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_double.cu.cc @@ -51,11 +51,11 @@ typedef TTypes::Tensor::Index Index; DEFINE(T, R, 3, 2); \ DEFINE_IDENTITY(T, R) -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MeanReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(double); diff --git a/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc index b9d737183977c2..70ba4abac48bcf 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_float.cu.cc @@ -51,11 +51,11 @@ typedef TTypes::Tensor::Index Index; DEFINE(T, R, 3, 2); \ DEFINE_IDENTITY(T, R) -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MeanReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(float); diff --git a/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc index 69296c7b65c253..82f6d7df952fcd 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu_int.cu.cc @@ -51,11 +51,11 @@ typedef TTypes::Tensor::Index Index; DEFINE(T, R, 3, 2); \ DEFINE_IDENTITY(T, R) -#define DEFINE_FOR_ALL_REDUCERS(T) \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MeanReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ +#define DEFINE_FOR_ALL_REDUCERS(T) \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer); \ + DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer); \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer) DEFINE_FOR_ALL_REDUCERS(int32); diff --git a/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc b/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc index 2120e22f99cbcf..db050fdea38bd6 100644 --- a/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc @@ -53,7 +53,7 @@ typedef TTypes::Tensor::Index Index; #define DEFINE_FOR_ALL_REDUCERS(T) \ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer); \ - DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MeanReducer); + DEFINE_FOR_TYPE_AND_R(T, functor::MeanReducer); DEFINE_FOR_ALL_REDUCERS(Eigen::half); #undef DEFINE_FOR_ALL_REDUCERS diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc index f61589f913b14b..67c974edda284d 100644 --- a/tensorflow/core/kernels/reduction_ops_mean.cc +++ b/tensorflow/core/kernels/reduction_ops_mean.cc @@ -17,39 +17,39 @@ limitations under the License. namespace tensorflow { -#define REGISTER_CPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("Mean") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER(Name("Mean") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx"), \ - ReductionOp>); +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx"), \ + ReductionOp>); TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS #if GOOGLE_CUDA -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("Mean") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER(Name("Mean") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); TF_CALL_complex64(REGISTER_GPU_KERNELS); TF_CALL_complex128(REGISTER_GPU_KERNELS); @@ -58,21 +58,21 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS); #endif #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("Mean") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); \ - REGISTER_KERNEL_BUILDER(Name("Mean") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .TypeConstraint("Tidx") \ - .HostMemory("reduction_indices"), \ - ReductionOp>); +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); REGISTER_SYCL_KERNELS(float); REGISTER_SYCL_KERNELS(double); #undef REGISTER_SYCL_KERNELS diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index bb96c42f10c498..1c4d0bc1ae9934 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -373,8 +373,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC); ReverseV2Op) TF_CALL_uint8(REGISTER_GPU_KERNELS); TF_CALL_int8(REGISTER_GPU_KERNELS); -// TODO decide whether we want to enable the bool kernel. -// TF_CALL_bool(REGISTER_GPU_KERNELS); +TF_CALL_bool(REGISTER_GPU_KERNELS); TF_CALL_half(REGISTER_GPU_KERNELS); TF_CALL_float(REGISTER_GPU_KERNELS); TF_CALL_double(REGISTER_GPU_KERNELS); diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 14b83f25140a6c..0b0ff95093e44d 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -294,19 +294,19 @@ TF_CALL_complex128(DECLARE_FOR_N); TF_CALL_bfloat16(DECLARE_FOR_N); TF_CALL_bool(DECLARE_FOR_N); TF_CALL_int8(DECLARE_FOR_N); +TF_CALL_int64(DECLARE_FOR_N); DECLARE_FOR_N(int32); #undef DECLARE_FOR_N #undef DECLARE_GPU_SPEC } // namespace functor -#define REGISTER_GPU(type) \ - REGISTER_KERNEL_BUILDER(Name("Slice") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("size") \ - .TypeConstraint("Index"), \ +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("Slice") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("size"), \ SliceOp) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -315,6 +315,7 @@ TF_CALL_complex128(REGISTER_GPU); TF_CALL_bfloat16(REGISTER_GPU); TF_CALL_bool(REGISTER_GPU); TF_CALL_int8(REGISTER_GPU); +TF_CALL_int64(REGISTER_GPU); // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -322,7 +323,6 @@ TF_CALL_int8(REGISTER_GPU); REGISTER_KERNEL_BUILDER(Name("Slice") .Device(DEVICE_GPU) .TypeConstraint("T") - .TypeConstraint("Index") .HostMemory("input") .HostMemory("begin") .HostMemory("size") diff --git a/tensorflow/core/kernels/slice_op_gpu.cu.cc b/tensorflow/core/kernels/slice_op_gpu.cu.cc index f6893aa6f017a7..044948f4065f97 100644 --- a/tensorflow/core/kernels/slice_op_gpu.cu.cc +++ b/tensorflow/core/kernels/slice_op_gpu.cu.cc @@ -43,6 +43,7 @@ TF_CALL_bfloat16(DEFINE_GPU_KERNELS); TF_CALL_bool(DEFINE_GPU_KERNELS); TF_CALL_int8(DEFINE_GPU_KERNELS); DEFINE_GPU_KERNELS(int32); +DEFINE_GPU_KERNELS(int64); #undef DEFINE_GPU_KERNELS diff --git a/tensorflow/core/kernels/stack.cc b/tensorflow/core/kernels/stack.cc new file mode 100644 index 00000000000000..5c70a2d62d36b9 --- /dev/null +++ b/tensorflow/core/kernels/stack.cc @@ -0,0 +1,339 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/stack.h" + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Stack : public ResourceBase { + public: + static std::atomic stack_counter; + + struct TensorAndAllocation { + Tensor tensor; + AllocatorAttributes alloc_attrs; + bool swapped_to_cpu; + }; + + Stack(const DataType& elem_type, const string& stack_name, int max_size) + : elem_type_(elem_type), + stack_name_(stack_name), + max_size_(max_size), + closed_(false) {} + + Status Push(const TensorAndAllocation& value) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(CheckNotClosed()); + if (max_size_ >= 0 && stack_.size() >= max_size_) { + return errors::InvalidArgument("Stack[", stack_name_, "] overflowed ", + "its max_size (", max_size_, ")"); + } + stack_.push_back(value); + return Status::OK(); + } + + Status Pop(TensorAndAllocation* value) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(CheckNotClosed()); + if (stack_.empty()) { + return errors::InvalidArgument("Stack[", stack_name_, + "] is empty when calling Pop()."); + } + *value = stack_.back(); + stack_.pop_back(); + return Status::OK(); + } + + // We don't swap the first tensor on the stack and any subsequent tensors + // that share the buffer with the first tensor. + bool IsUsefulToSwap(const Tensor& tensor) const { + mutex_lock l(mu_); + if (stack_.empty()) { + return false; + } + const Tensor& first = stack_.front().tensor; + return !tensor.SharesBufferWith(first); + } + + void Close() { + mutex_lock l(mu_); + stack_.clear(); + closed_ = true; + } + + DataType ElemType() { return elem_type_; } + + string DebugString() override { + mutex_lock l(mu_); + return strings::StrCat("Stack[", stack_name_, "]"); + } + + const string& stack_name() { return stack_name_; } + + private: + friend class StackOp; + mutex* mu() { return &mu_; } + + mutable mutex mu_; + DataType elem_type_; + const string stack_name_; + Tensor handle_; + int max_size_; + bool closed_ GUARDED_BY(mu_); + std::vector stack_ GUARDED_BY(mu_); + + Status CheckNotClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + return errors::InvalidArgument("Stack[", stack_name_, + "] has already been closed."); + } + return Status::OK(); + } +}; + +Status GetStack(OpKernelContext* ctx, Stack** stack) { + if (ctx->input_dtype(0) == DT_RESOURCE) { + return LookupResource(ctx, HandleFromInput(ctx, 0), stack); + } else { + Tensor Tstack_handle = ctx->mutable_input(0, false); + if (Tstack_handle.NumElements() != 2) { + return errors::InvalidArgument( + "Stack handle must have two elements, but had shape: ", + Tstack_handle.shape().DebugString()); + } + const string& container = Tstack_handle.flat()(0); + const string& stack_name = Tstack_handle.flat()(1); + string key = strings::StrCat(container, stack_name); + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) { + return errors::Internal("No resource manager."); + } + auto* step_container = ctx->step_container(); + if (step_container == nullptr) { + return errors::Internal("No step container."); + } + TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); + return Status::OK(); + } +} + +std::atomic Stack::stack_counter{0}; + +// StackOp + +StackOp::StackOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("elem_type", &elem_type_)); + OP_REQUIRES_OK(context, context->GetAttr("stack_name", &stack_name_)); + if (stack_name_.empty()) stack_name_ = name(); +} + +void StackOp::Compute(OpKernelContext* ctx) { + int32 size = std::numeric_limits::max(); + if (ctx->num_inputs() > 0) { + const Tensor* tensor_size; + OP_REQUIRES_OK(ctx, ctx->input("max_size", &tensor_size)); + + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(tensor_size->shape()), + errors::InvalidArgument("Stack size must be a scalar, but had shape: ", + tensor_size->shape().DebugString())); + + int32 size_value = tensor_size->scalar()(); + if (size_value >= 0) { + size = size_value; + } + } + + static const char kContainer[] = "_stacks"; + auto stack_id = Stack::stack_counter.fetch_add(1); + string stack_name = strings::StrCat(stack_name_, "_", stack_id); + // Store the handle in a per-step container. + ResourceMgr* rm = ctx->resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager.")); + string key = strings::StrCat(kContainer, stack_name); + Stack* stack = new Stack(elem_type_, stack_name, size); + auto* step_container = ctx->step_container(); + OP_REQUIRES(ctx, step_container != nullptr, + errors::Internal("No step container.")); + OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack)); + if (IsRefType(ctx->expected_output_dtype(0))) { + // Create the stack handle. + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING, + tensorflow::TensorShape({2}), + &stack->handle_, alloc_attr)); + auto handle = stack->handle_.flat(); + handle(0) = kContainer; + handle(1) = std::move(stack_name); + ctx->set_output_ref(0, stack->mu(), &stack->handle_); + } else { + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->flat()(0) = + MakePerStepResourceHandle(ctx, key); + } +} + +// StackPushOp + +StackPushOp::StackPushOp(OpKernelConstruction* context, bool allow_swapping) + : AsyncOpKernel(context) { + if (allow_swapping) { + OP_REQUIRES_OK(context, context->GetAttr("swap_memory", &swap_memory_)); + } +} + +void StackPushOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + // Get the stack from the handle. + Stack* stack = nullptr; + OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done); + core::ScopedUnref unref(stack); + + if (ctx->input_dtype(1) != stack->ElemType()) { + ctx->CtxFailure(errors::InvalidArgument("Must have type ", + stack->ElemType(), " but got ", + ctx->input_dtype(1))); + done(); + return; + } + + // Push the tensor onto the stack. Swap the tensor to CPU if instructed. + const Tensor& tensor = ctx->input(1); + AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); + // For now, we use a simple heuristic for swapping: A GPU tensor is moved + // to CPU if the tensor has more than kCopyThreshold bytes and the GPU + // allocator says more than kOccupancy of the memory is in use. + static constexpr int kCopyThreshold = 2048; + static constexpr double kOccupancy = 0.7; + if (swap_memory_ && !alloc_attrs.on_host() && + tensor.TotalBytes() > kCopyThreshold && stack->IsUsefulToSwap(tensor)) { + DeviceContext* device_ctxt = ctx->op_device_context(); + auto device = static_cast(ctx->device()); + Allocator* allocator = device->GetAllocator(alloc_attrs); + AllocatorStats stats; + allocator->GetStats(&stats); + if (stats.bytes_in_use > (stats.bytes_limit * kOccupancy)) { + // Asynchronously copy the tensor from GPU to CPU memory. + // TODO(yuanbyu): Swap the oldest tensor first. + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs); + Tensor* cpu_tensor = + new Tensor(cpu_allocator, tensor.dtype(), tensor.shape()); + device_ctxt->CopyDeviceTensorToCPU( + &tensor, "StackPush", device, cpu_tensor, + [cpu_tensor, stack, ctx, done](const Status& s) { + ctx->SetStatus(s); + if (s.ok()) { + AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); + ctx->SetStatus(stack->Push({*cpu_tensor, alloc_attrs, true})); + } + if (ctx->status().ok()) { + ctx->set_output(0, *cpu_tensor); + } + done(); + delete cpu_tensor; + }); + return; + } + } + + // Execute synchronously if not swapped. + OP_REQUIRES_OK_ASYNC(ctx, stack->Push({tensor, alloc_attrs, false}), done); + ctx->set_output(0, tensor); + done(); +} + +bool StackPushOp::IsExpensive() { return false; } + +// StackPopOp + +StackPopOp::StackPopOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + +void StackPopOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + // Get the stack from the handle. + Stack* stack = nullptr; + OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done); + core::ScopedUnref unref(stack); + + // Pop the tensor. Transfer the tensor back to device if it was + // swapped out to CPU. + Stack::TensorAndAllocation value; + OP_REQUIRES_OK_ASYNC(ctx, stack->Pop(&value), done); + if (value.swapped_to_cpu) { + // Asynchronously copy the tensor back from CPU to GPU memory. + DeviceContext* device_ctxt = ctx->op_device_context(); + Device* device = static_cast(ctx->device()); + Tensor* cpu_tensor = &value.tensor; + Allocator* gpu_allocator = device->GetAllocator(value.alloc_attrs); + Tensor* device_tensor = + new Tensor(gpu_allocator, cpu_tensor->dtype(), cpu_tensor->shape()); + device_ctxt->CopyCPUTensorToDevice( + cpu_tensor, device, device_tensor, + [device_tensor, ctx, done](const Status& s) { + ctx->SetStatus(s); + if (s.ok()) { + ctx->set_output(0, *device_tensor); + } + done(); + delete device_tensor; + }); + } else { + // Execute synchronously if not swapped. + ctx->set_output(0, value.tensor); + done(); + } +} + +bool StackPopOp::IsExpensive() { return false; } + +// StackCloseOp + +StackCloseOp::StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {} + +void StackCloseOp::Compute(OpKernelContext* ctx) { + Stack* stack = nullptr; + OP_REQUIRES_OK(ctx, GetStack(ctx, &stack)); + core::ScopedUnref unref(stack); + stack->Close(); +} + +bool StackCloseOp::IsExpensive() { return false; } + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/stack.h b/tensorflow/core/kernels/stack.h new file mode 100644 index 00000000000000..e1927e1f28fa21 --- /dev/null +++ b/tensorflow/core/kernels/stack.h @@ -0,0 +1,76 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STACK_H_ +#define TENSORFLOW_CORE_KERNELS_STACK_H_ + +// See docs in ../ops/data_flow_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// A per-run local stack. The stack uses a "per-step" resource manager which +// ensures that correct garbage collection on error or successful completion. +class StackOp : public OpKernel { + public: + explicit StackOp(OpKernelConstruction* context); + void Compute(OpKernelContext* ctx) override; + + private: + DataType elem_type_; + string stack_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(StackOp); +}; + +class StackPushOp : public AsyncOpKernel { + public: + StackPushOp(OpKernelConstruction* context, bool allow_swapping); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + bool IsExpensive() override; + + private: + bool swap_memory_ = false; +}; + +// Templated helper to make it easier to register kernels with or without +// swapping. +template +class TemplatedStackPushOp : public StackPushOp { + public: + TemplatedStackPushOp(OpKernelConstruction* context) + : StackPushOp(context, allow_swapping) {} +}; + +class StackPopOp : public AsyncOpKernel { + public: + explicit StackPopOp(OpKernelConstruction* context); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + bool IsExpensive() override; +}; + +class StackCloseOp : public OpKernel { + public: + explicit StackCloseOp(OpKernelConstruction* context); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STACK_H_ diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index add4afafc92d4e..df94a8818e7edd 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -15,6 +15,8 @@ limitations under the License. // See docs in ../ops/data_flow_ops.cc. +#include "tensorflow/core/kernels/stack.h" + #include #include #include @@ -38,191 +40,6 @@ limitations under the License. namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; -#ifdef TENSORFLOW_USE_SYCL -typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL - -class Stack : public ResourceBase { - public: - static std::atomic stack_counter; - - struct TensorAndAllocation { - Tensor tensor; - AllocatorAttributes alloc_attrs; - bool swapped_to_cpu; - }; - - Stack(const DataType& elem_type, const string& stack_name, int max_size) - : elem_type_(elem_type), - stack_name_(stack_name), - max_size_(max_size), - closed_(false) {} - - Status Push(const TensorAndAllocation& value) { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(CheckNotClosed()); - if (max_size_ >= 0 && stack_.size() >= max_size_) { - return errors::InvalidArgument("Stack[", stack_name_, "] overflowed ", - "its max_size (", max_size_, ")"); - } - stack_.push_back(value); - return Status::OK(); - } - - Status Pop(TensorAndAllocation* value) { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(CheckNotClosed()); - if (stack_.empty()) { - return errors::InvalidArgument("Stack[", stack_name_, - "] is empty when calling Pop()."); - } - *value = stack_.back(); - stack_.pop_back(); - return Status::OK(); - } - - // We don't swap the first tensor on the stack and any subsequent tensors - // that share the buffer with the first tensor. - bool IsUsefulToSwap(const Tensor& tensor) const { - mutex_lock l(mu_); - if (stack_.empty()) { - return false; - } - const Tensor& first = stack_.front().tensor; - return !tensor.SharesBufferWith(first); - } - - void Close() { - mutex_lock l(mu_); - stack_.clear(); - closed_ = true; - } - - DataType ElemType() { return elem_type_; } - - string DebugString() override { - mutex_lock l(mu_); - return strings::StrCat("Stack[", stack_name_, "]"); - } - - const string& stack_name() { return stack_name_; } - - private: - friend class StackOp; - mutex* mu() { return &mu_; } - - mutable mutex mu_; - DataType elem_type_; - const string stack_name_; - Tensor handle_; - int max_size_; - bool closed_ GUARDED_BY(mu_); - std::vector stack_ GUARDED_BY(mu_); - - Status CheckNotClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (closed_) { - return errors::InvalidArgument("Stack[", stack_name_, - "] has already been closed."); - } - return Status::OK(); - } -}; - -Status GetStack(OpKernelContext* ctx, Stack** stack) { - if (ctx->input_dtype(0) == DT_RESOURCE) { - return LookupResource(ctx, HandleFromInput(ctx, 0), stack); - } else { - Tensor Tstack_handle = ctx->mutable_input(0, false); - if (Tstack_handle.NumElements() != 2) { - return errors::InvalidArgument( - "Stack handle must have two elements, but had shape: ", - Tstack_handle.shape().DebugString()); - } - const string& container = Tstack_handle.flat()(0); - const string& stack_name = Tstack_handle.flat()(1); - string key = strings::StrCat(container, stack_name); - ResourceMgr* rm = ctx->resource_manager(); - if (rm == nullptr) { - return errors::Internal("No resource manager."); - } - auto* step_container = ctx->step_container(); - if (step_container == nullptr) { - return errors::Internal("No step container."); - } - TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); - return Status::OK(); - } -} - -std::atomic Stack::stack_counter{0}; - -// A per-run local stack. The stack uses a "per-step" resource manager which -// ensures that correct garbage collection on error or successful completion. -class StackOp : public OpKernel { - public: - explicit StackOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("elem_type", &elem_type_)); - OP_REQUIRES_OK(context, context->GetAttr("stack_name", &stack_name_)); - if (stack_name_.empty()) stack_name_ = name(); - } - - void Compute(OpKernelContext* ctx) override { - int32 size = std::numeric_limits::max(); - if (ctx->num_inputs() > 0) { - const Tensor* tensor_size; - OP_REQUIRES_OK(ctx, ctx->input("max_size", &tensor_size)); - - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_size->shape()), - errors::InvalidArgument( - "Stack size must be a scalar, but had shape: ", - tensor_size->shape().DebugString())); - - int32 size_value = tensor_size->scalar()(); - if (size_value >= 0) { - size = size_value; - } - } - - static const char kContainer[] = "_stacks"; - auto stack_id = Stack::stack_counter.fetch_add(1); - string stack_name = strings::StrCat(stack_name_, "_", stack_id); - // Store the handle in a per-step container. - ResourceMgr* rm = ctx->resource_manager(); - OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager.")); - string key = strings::StrCat(kContainer, stack_name); - Stack* stack = new Stack(elem_type_, stack_name, size); - auto* step_container = ctx->step_container(); - OP_REQUIRES(ctx, step_container != nullptr, - errors::Internal("No step container.")); - OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack)); - if (IsRefType(ctx->expected_output_dtype(0))) { - // Create the stack handle. - AllocatorAttributes alloc_attr; - alloc_attr.set_on_host(true); - OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING, - tensorflow::TensorShape({2}), - &stack->handle_, alloc_attr)); - auto handle = stack->handle_.flat(); - handle(0) = kContainer; - handle(1) = std::move(stack_name); - ctx->set_output_ref(0, stack->mu(), &stack->handle_); - } else { - Tensor* handle; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); - handle->flat()(0) = - MakePerStepResourceHandle(ctx, key); - } - } - - private: - DataType elem_type_; - string stack_name_; - - TF_DISALLOW_COPY_AND_ASSIGN(StackOp); -}; - REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp); REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"), StackOp); @@ -242,102 +59,22 @@ REGISTER_KERNEL_BUILDER(Name("StackV2") StackOp); #endif // TENSORFLOW_USE_SYCL -template -class StackPushOp : public AsyncOpKernel { - public: - explicit StackPushOp(OpKernelConstruction* context) : AsyncOpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("swap_memory", &swap_memory_)); - } - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - // Get the stack from the handle. - Stack* stack = nullptr; - OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done); - core::ScopedUnref unref(stack); - - if (ctx->input_dtype(1) != stack->ElemType()) { - ctx->CtxFailure(errors::InvalidArgument("Must have type ", - stack->ElemType(), " but got ", - ctx->input_dtype(1))); - done(); - return; - } - - // Push the tensor onto the stack. Swap the tensor to CPU if instructed. - const Tensor& tensor = ctx->input(1); - AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); - // For now, we use a simple heuristic for swapping: A GPU tensor is moved - // to CPU if the tensor has more than kCopyThreshold bytes and the GPU - // allocator says more than kOccupancy of the memory is in use. - static constexpr int kCopyThreshold = 2048; - static constexpr double kOccupancy = 0.7; - if (swap_memory_ && !alloc_attrs.on_host() && - (std::is_same::value -#ifdef TENSORFLOW_USE_SYCL - || std::is_same::value -#endif // TENSORFLOW_USE_SYCL - ) && - tensor.TotalBytes() > kCopyThreshold && stack->IsUsefulToSwap(tensor)) { - DeviceContext* device_ctxt = ctx->op_device_context(); - auto device = static_cast(ctx->device()); - Allocator* allocator = device->GetAllocator(alloc_attrs); - AllocatorStats stats; - allocator->GetStats(&stats); - if (stats.bytes_in_use > (stats.bytes_limit * kOccupancy)) { - // Asynchronously copy the tensor from GPU to CPU memory. - // TODO(yuanbyu): Swap the oldest tensor first. - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs); - Tensor* cpu_tensor = - new Tensor(cpu_allocator, tensor.dtype(), tensor.shape()); - device_ctxt->CopyDeviceTensorToCPU( - &tensor, "StackPush", device, cpu_tensor, - [cpu_tensor, stack, ctx, done](const Status& s) { - ctx->SetStatus(s); - if (s.ok()) { - AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); - ctx->SetStatus(stack->Push({*cpu_tensor, alloc_attrs, true})); - } - if (ctx->status().ok()) { - ctx->set_output(0, *cpu_tensor); - } - done(); - delete cpu_tensor; - }); - return; - } - } - - // Execute synchronously if not swapped. - OP_REQUIRES_OK_ASYNC(ctx, stack->Push({tensor, alloc_attrs, false}), done); - ctx->set_output(0, tensor); - done(); - } - - bool IsExpensive() override { return false; } - - private: - bool swap_memory_; -}; - REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU), - StackPushOp); + TemplatedStackPushOp); REGISTER_KERNEL_BUILDER(Name("StackPushV2").Device(DEVICE_CPU), - StackPushOp); - -#define REGISTER_GPU_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("StackPush") \ - .Device(DEVICE_GPU) \ - .HostMemory("handle") \ - .TypeConstraint("T"), \ - StackPushOp); \ - REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ - .Device(DEVICE_GPU) \ - .HostMemory("handle") \ - .TypeConstraint("T"), \ - StackPushOp); + TemplatedStackPushOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("StackPush") \ + .Device(DEVICE_GPU) \ + .HostMemory("handle") \ + .TypeConstraint("T"), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("handle") \ + .TypeConstraint("T"), \ + TemplatedStackPushOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL @@ -345,21 +82,21 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); // Special GPU kernels for int32 and bool. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. -#define REGISTER_GPU_HOST_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("StackPush") \ - .Device(DEVICE_GPU) \ - .HostMemory("handle") \ - .HostMemory("elem") \ - .HostMemory("output") \ - .TypeConstraint("T"), \ - StackPushOp); \ - REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ - .Device(DEVICE_GPU) \ - .HostMemory("handle") \ - .HostMemory("elem") \ - .HostMemory("output") \ - .TypeConstraint("T"), \ - StackPushOp); +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("StackPush") \ + .Device(DEVICE_GPU) \ + .HostMemory("handle") \ + .HostMemory("elem") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("handle") \ + .HostMemory("elem") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + TemplatedStackPushOp); REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(bool); @@ -372,7 +109,7 @@ REGISTER_GPU_HOST_KERNEL(bool); .Device(DEVICE_SYCL) \ .HostMemory("handle") \ .TypeConstraint("T"), \ - StackPushOp); + TemplatedStackPushOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL_KERNEL); @@ -383,7 +120,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL_KERNEL); .HostMemory("elem") \ .HostMemory("output") \ .TypeConstraint("T"), \ - StackPushOp) + TemplatedStackPushOp) REGISTER_SYCL_HOST_KERNEL(int32); REGISTER_SYCL_HOST_KERNEL(bool); @@ -391,48 +128,6 @@ REGISTER_SYCL_HOST_KERNEL(bool); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -class StackPopOp : public AsyncOpKernel { - public: - explicit StackPopOp(OpKernelConstruction* context) : AsyncOpKernel(context) {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - // Get the stack from the handle. - Stack* stack = nullptr; - OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done); - core::ScopedUnref unref(stack); - - // Pop the tensor. Transfer the tensor back to device if it was - // swapped out to CPU. - Stack::TensorAndAllocation value; - OP_REQUIRES_OK_ASYNC(ctx, stack->Pop(&value), done); - if (value.swapped_to_cpu) { - // Asynchronously copy the tensor back from CPU to GPU memory. - DeviceContext* device_ctxt = ctx->op_device_context(); - Device* device = static_cast(ctx->device()); - Tensor* cpu_tensor = &value.tensor; - Allocator* gpu_allocator = device->GetAllocator(value.alloc_attrs); - Tensor* device_tensor = - new Tensor(gpu_allocator, cpu_tensor->dtype(), cpu_tensor->shape()); - device_ctxt->CopyCPUTensorToDevice( - cpu_tensor, device, device_tensor, - [device_tensor, ctx, done](const Status& s) { - ctx->SetStatus(s); - if (s.ok()) { - ctx->set_output(0, *device_tensor); - } - done(); - delete device_tensor; - }); - } else { - // Execute synchronously if not swapped. - ctx->set_output(0, value.tensor); - done(); - } - } - - bool IsExpensive() override { return false; } -}; - REGISTER_KERNEL_BUILDER(Name("StackPop").Device(DEVICE_CPU), StackPopOp); REGISTER_KERNEL_BUILDER(Name("StackPopV2").Device(DEVICE_CPU), StackPopOp); @@ -498,20 +193,6 @@ REGISTER_SYCL_HOST_KERNEL(bool); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -class StackCloseOp : public OpKernel { - public: - explicit StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* ctx) override { - Stack* stack = nullptr; - OP_REQUIRES_OK(ctx, GetStack(ctx, &stack)); - core::ScopedUnref unref(stack); - stack->Close(); - } - - bool IsExpensive() override { return false; } -}; - REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp); REGISTER_KERNEL_BUILDER( Name("StackClose").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp); diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index 91058b5a2f4e3f..c4205159c380cb 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -179,7 +179,7 @@ class HandleStridedSliceAssignCase { } }; -// NODE(aselle): according to bsteiner, we need this because otherwise +// NOTE(aselle): according to bsteiner, we need this because otherwise // nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu // handles instantiates externally. It is important that this is done // before the HandleXXCase's are instantiated to avoid duplicate diff --git a/tensorflow/core/kernels/tile_functor.h b/tensorflow/core/kernels/tile_functor.h index 95986af8b77a05..9a460d191fc917 100644 --- a/tensorflow/core/kernels/tile_functor.h +++ b/tensorflow/core/kernels/tile_functor.h @@ -36,9 +36,11 @@ void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in, auto x = in.tensor(); auto y = out->tensor(); + bool use_32bit = y.size() < Eigen::NumTraits::highest(); + Eigen::array b; for (int i = 0; i < NDIM; ++i) b[i] = broadcast_array[i]; - if (Eigen::internal::is_same::value) { + if (use_32bit && Eigen::internal::is_same::value) { // Use 32bit indexing to speed up the computations To32Bit(y).device(d) = To32Bit(x).broadcast(b); } else { diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 329a8b9f5e6dad..916deabd6ff677 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -21903,6 +21903,29 @@ op { } is_stateful: true } +op { + name: "ExperimentalNonSerializableDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ExperimentalNumaMapAndBatchDataset" input_arg { @@ -29132,6 +29155,75 @@ op { } } } +op { + name: "LeakyRelu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { + name: "LeakyReluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "alpha" + type: "float" + default_value { + f: 0.2 + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} op { name: "LeakyReluGrad" input_arg { @@ -29162,6 +29254,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index d077954f9eb041..088d1865ddf074 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -75,6 +75,13 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ExperimentalNonSerializableDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("ExperimentalSleepDataset") .Input("input_dataset: variant") .Input("sleep_microseconds: int64") diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 38fe45936a12a7..9796587709bba3 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -987,7 +987,7 @@ REGISTER_OP("LeakyRelu") .Input("features: T") .Output("activations: T") .Attr("alpha: float = 0.2") - .Attr("T: {half, float, double} = DT_FLOAT") + .Attr("T: {half, bfloat16, float, double} = DT_FLOAT") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("LeakyReluGrad") @@ -995,7 +995,7 @@ REGISTER_OP("LeakyReluGrad") .Input("features: T") .Output("backprops: T") .Attr("alpha: float = 0.2") - .Attr("T: {half, float, double} = DT_FLOAT") + .Attr("T: {half, bfloat16, float, double} = DT_FLOAT") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); REGISTER_OP("Elu") @@ -2385,7 +2385,7 @@ REGISTER_OP("_MklToTf") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, float, double, qint8, quint8, qint32}") .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( @@ -2418,6 +2418,343 @@ element-wise MKL op. NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); + +REGISTER_OP("QuantizedConv2DAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT8") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +// Fusion of Quantized Conv2D and BiasAdd. +REGISTER_OP("QuantizedConv2DWithBias") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("QuantizedConv2DWithBiasAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: Tbias") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("out_type: quantizedtype = DT_QINT8") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +// Fusion of Quantized Conv2D and Relu. +REGISTER_OP("QuantizedConv2DAndRelu") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("QuantizedConv2DAndReluAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QUINT8") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +// Fusion of Quantized Conv2D, BiasAdd and Relu. +REGISTER_OP("QuantizedConv2DWithBiasAndRelu") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +// Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize. +REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: Tbias") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("out_type: quantizedtype = DT_QUINT8") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +// Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu. +REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("summand: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: Tbias") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Input("summand: Tsummand") + .Input("min_summand: float") + .Input("max_summand: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("Tsummand: quantizedtype") + .Attr("out_type: quantizedtype = DT_QUINT8") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: Tbias") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Input("summand: Tsummand") + .Input("min_summand: float") + .Input("max_summand: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("Tsummand: quantizedtype") + .Attr("out_type: quantizedtype = DT_QUINT8") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + #endif // INTEL_MKL } // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 68d8e8e25a84ce..f590794f3054a8 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -10366,6 +10366,29 @@ op { } is_stateful: true } +op { + name: "ExperimentalNonSerializableDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ExperimentalNumaMapAndBatchDataset" input_arg { @@ -14397,6 +14420,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -14433,6 +14457,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 104ab039cb71e5..4689af06afedb2 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -5,7 +5,7 @@ option cc_enable_arenas = true; option java_outer_classname = "ConfigProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf"; +// add go_package externally with copybara import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; @@ -148,6 +148,14 @@ message GPUOptions { // for each GPUDevice. Default value is 0, which is automatically // converted to 1. int32 num_dev_to_dev_copy_streams = 3; + + // If non-empty, defines a good GPU ring order on a single worker based on + // device interconnect. This assumes that all workers have the same GPU + // topology. Specify as a comma-separated string, e.g. "3,2,1,0,7,6,5,4". + // This ring order is used by the RingReducer implementation of + // CollectiveReduce, and serves as an override to automatic ring order + // generation in OrderTaskDeviceMap() during CollectiveParam resolution. + string collective_ring_order = 4; } // Everything inside experimental is subject to change and is not subject @@ -400,6 +408,11 @@ message ConfigProto { // Which executor to use, the default executor will be used // if it is an empty string or "DEFAULT" string executor_type = 3; + + // Guidance to formatting of large RecvBuf fields for transfer. + // Any positive value sets the max chunk size. 0 defaults to 4096. + // Any negative value indicates no max, i.e. one chunk only. + int32 recv_buf_max_chunk = 4; }; Experimental experimental = 16; diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 143df115f424ae..d68f2735365b0a 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -137,6 +137,11 @@ message RewriterConfig { // meta-optimizer or when manually specified through the optimizers field. AutoParallelOptions auto_parallel = 5; + // If true, any optimization pass failing will cause the MetaOptimizer to + // stop with an error. By default - or when set to false, failing passes are + // skipped silently. + bool fail_on_optimizer_errors = 21; + ScopedAllocatorOptions scoped_allocator_opts = 16; // If non-empty, will use this as an alternative way to specify a list of diff --git a/tensorflow/core/protobuf/transport_options.proto b/tensorflow/core/protobuf/transport_options.proto index d7b1bddbbe3d7d..1d32475e9b9d6c 100644 --- a/tensorflow/core/protobuf/transport_options.proto +++ b/tensorflow/core/protobuf/transport_options.proto @@ -4,5 +4,5 @@ package tensorflow; // Extra data needed on a non-RDMA RecvBufResponse. message RecvBufRespExtra { - bytes tensor_content = 1; + repeated bytes tensor_content = 1; }; diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index 24002e72a0920f..7382b8e6849b88 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -146,4 +146,4 @@ class BeamComparer { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ -// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h) +// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_entry.h) diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h index 1e45a8abd39a75..fc63dfb0fd2901 100644 --- a/tensorflow/core/util/ctc/ctc_beam_scorer.h +++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -74,4 +74,4 @@ class BaseBeamScorer { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_ -// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h) +// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_scorer.h) diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 6fbb1ed0dae179..f2022d486c76e2 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -431,4 +431,4 @@ Status CTCBeamSearchDecoder::TopPaths( } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ -// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h) +// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_beam_search.h) diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index b55d7d77ac0f07..f5c9e4bb596dac 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -113,4 +113,4 @@ class CTCGreedyDecoder : public CTCDecoder { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ -// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h) +// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_decoder.h) diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h index 054412d388dd53..df0de926d9a8b6 100644 --- a/tensorflow/core/util/ctc/ctc_loss_util.h +++ b/tensorflow/core/util/ctc/ctc_loss_util.h @@ -47,4 +47,4 @@ inline float LogSumExp(float log_prob_1, float log_prob_2) { } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ -// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h) +// LINT.ThenChange(//tensorflow/lite/experimental/kernels/ctc_loss_util.h) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index d9706b5478335a..d65063fe794909 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -5475,6 +5475,78 @@ func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes [] return values } +// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. +type OrderedMapPeekAttr func(optionalAttr) + +// OrderedMapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the +// +// underlying container does not contain this key +// this op will block until it does. This Op is optimized for +// performance. +func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OrderedMapPeek", + Input: []tf.Input{ + key, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapPeek", err) + return + } + return values +} + // Returns the truth value of x OR y element-wise. // // *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting @@ -6994,6 +7066,36 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y return op.Output(0) } +// Outputs a tensor containing the reduction across all input tensors. +// +// Outputs a tensor containing the reduction across all input tensors passed to ops +// within the same `shared_name. +// +// The graph should be constructed so if one op runs with shared_name value `c`, +// then `num_devices` ops will run with shared_name value `c`. Failure to do so +// will cause the graph execution to fail to complete. +// +// input: the input to the reduction +// data: the value of the reduction across all `num_devices` devices. +// reduction: the reduction operation to perform. +// num_devices: The number of devices participating in this reduction. +// shared_name: Identifier that shared between ops of the same reduction. +func NcclAllReduce(scope *Scope, input tf.Output, reduction string, num_devices int64, shared_name string) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"reduction": reduction, "num_devices": num_devices, "shared_name": shared_name} + opspec := tf.OpSpec{ + Type: "NcclAllReduce", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // RegexReplaceAttr is an optional argument to RegexReplace. type RegexReplaceAttr func(optionalAttr) @@ -8971,6 +9073,32 @@ func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dens return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } +// Reduces `input` from `num_devices` using `reduction` to a single device. +// +// Reduces `input` from `num_devices` using `reduction` to a single device. +// +// The graph should be constructed so that all inputs have a valid device +// assignment, and the op itself is assigned one of these devices. +// +// input: The input to the reduction. +// data: the value of the reduction across all `num_devices` devices. +// reduction: the reduction operation to perform. +func NcclReduce(scope *Scope, input []tf.Output, reduction string) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"reduction": reduction} + opspec := tf.OpSpec{ + Type: "NcclReduce", + Input: []tf.Input{ + tf.OutputList(input), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // BiasAddGradAttr is an optional argument to BiasAddGrad. type BiasAddGradAttr func(optionalAttr) @@ -10738,6 +10866,31 @@ func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_ou return op.Output(0) } +// Converts a `RaggedTensor` into a `SparseTensor` with the same values. +// +// input=ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) +// output=SparseTensor(indices=sparse_indices, values=sparse_values, +// dense_shape=sparse_dense_shape) +// +// Arguments: +// rt_nested_splits: The `row_splits` for the `RaggedTensor`. +// rt_dense_values: The `inner_values` for the `RaggedTensor`. +// +// Returns The indices for the `SparseTensor`.The values of the `SparseTensor`.`sparse_dense_shape` is a tight bounding box of the input `RaggedTensor`. +func RaggedTensorToSparse(scope *Scope, rt_nested_splits []tf.Output, rt_dense_values tf.Output) (sparse_indices tf.Output, sparse_values tf.Output, sparse_dense_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RaggedTensorToSparse", + Input: []tf.Input{ + tf.OutputList(rt_nested_splits), rt_dense_values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // Check if the input matches the regex pattern. // // The input is a string tensor of any shape. The pattern is a scalar @@ -15397,78 +15550,6 @@ func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf return op.Output(0), op.Output(1), op.Output(2) } -// OrderedMapPeekAttr is an optional argument to OrderedMapPeek. -type OrderedMapPeekAttr func(optionalAttr) - -// OrderedMapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapPeekContainer(value string) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the -// -// underlying container does not contain this key -// this op will block until it does. This Op is optimized for -// performance. -func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OrderedMapPeek", - Input: []tf.Input{ - key, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapPeek", err) - return - } - return values -} - // LRNAttr is an optional argument to LRN. type LRNAttr func(optionalAttr) @@ -22120,6 +22201,108 @@ func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, option return op.Output(0) } +// UnicodeTranscodeAttr is an optional argument to UnicodeTranscode. +type UnicodeTranscodeAttr func(optionalAttr) + +// UnicodeTranscodeErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeTranscodeErrors(value string) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeTranscodeReplacementChar sets the optional replacement_char attribute to value. +// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD or U+65533.) +// +// Note that for UTF-8, passing a replacement character expressible in 1 byte, such +// as ' ', will preserve string alignment to the source since invalid bytes will be +// replaced with a 1-byte replacement. For UTF-16-BE and UTF-16-LE, any 1 or 2 byte +// replacement character will preserve byte alignment to the source. +// If not specified, defaults to 65533 +func UnicodeTranscodeReplacementChar(value int64) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeTranscodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeTranscodeReplaceControlCharacters(value bool) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// Transcode the input text from a source encoding to a destination encoding. +// +// The input is a string tensor of any shape. The output is a string tensor of +// the same shape containing the transcoded strings. Output strings are always +// valid unicode. If the input contains invalid encoding positions, the +// `errors` attribute sets the policy for how to deal with them. If the default +// error-handling policy is used, invalid formatting will be substituted in the +// output by the `replacement_char`. If the errors policy is to `ignore`, any +// invalid encoding positions in the input are skipped and not included in the +// output. If it set to `strict` then any invalid formatting will result in an +// InvalidArgument error. +// +// This operation can be used with `output_encoding = input_encoding` to enforce +// correct formatting for inputs even if they are already in the desired encoding. +// +// If the input is prefixed by a Byte Order Mark needed to determine encoding +// (e.g. if the encoding is UTF-16 and the BOM indicates big-endian), then that +// BOM will be consumed and not emitted into the output. If the input encoding +// is marked with an explicit endianness (e.g. UTF-16-BE), then the BOM is +// interpreted as a non-breaking-space and is preserved in the output (including +// always for UTF-8). +// +// The end result is that if the input is marked as an explicit endianness the +// transcoding is faithful to all codepoints in the source. If it is not marked +// with an explicit endianness, the BOM is not considered part of the string itself +// but as metadata, and so is not preserved in the output. +// +// Arguments: +// input: The text to be processed. Can have any shape. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// output_encoding: The unicode encoding to use in the output. Must be one of +// `"UTF-8", "UTF-16-BE", "UTF-32-BE"`. Multi-byte encodings will be big-endian. +// +// Returns A string tensor containing unicode text encoded using `output_encoding`. +func UnicodeTranscode(scope *Scope, input tf.Output, input_encoding string, output_encoding string, optional ...UnicodeTranscodeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"input_encoding": input_encoding, "output_encoding": output_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeTranscode", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes inverse hyperbolic sine of x element-wise. func Asinh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -24464,6 +24647,33 @@ func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output return op.Output(0) } +// Sends `input` to all devices that are connected to the output. +// +// Sends `input` to all devices that are connected to the output. +// +// The graph should be constructed so that all ops connected to the output have a +// valid device assignment, and the op itself is assigned one of these devices. +// +// input: The input to the broadcast. +// output: The same as input. +// shape: The shape of the input tensor. +// +func NcclBroadcast(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "NcclBroadcast", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResizeAreaAttr is an optional argument to ResizeArea. type ResizeAreaAttr func(optionalAttr) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/lite/BUILD similarity index 76% rename from tensorflow/contrib/lite/BUILD rename to tensorflow/lite/BUILD index 787a85644c35c8..f8bb7191c4eb8f 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -5,7 +5,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") +load("//tensorflow/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") exports_files(glob([ "testdata/*.bin", @@ -48,7 +48,7 @@ cc_library( ":graph_info", ":memory_planner", ":simple_memory_arena", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -62,9 +62,9 @@ cc_test( ], deps = [ ":arena_planner", - "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -74,26 +74,26 @@ cc_test( cc_library( name = "context", hdrs = ["context.h"], - deps = ["//tensorflow/contrib/lite/c:c_api_internal"], + deps = ["//tensorflow/lite/c:c_api_internal"], ) cc_library( name = "graph_info", hdrs = ["graph_info.h"], - deps = ["//tensorflow/contrib/lite/c:c_api_internal"], + deps = ["//tensorflow/lite/c:c_api_internal"], ) cc_library( name = "memory_planner", hdrs = ["memory_planner.h"], - deps = ["//tensorflow/contrib/lite/c:c_api_internal"], + deps = ["//tensorflow/lite/c:c_api_internal"], ) cc_library( name = "simple_memory_arena", srcs = ["simple_memory_arena.cc"], hdrs = ["simple_memory_arena.h"], - deps = ["//tensorflow/contrib/lite/c:c_api_internal"], + deps = ["//tensorflow/lite/c:c_api_internal"], ) cc_library( @@ -101,7 +101,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], - deps = ["//tensorflow/contrib/lite/c:c_api_internal"], + deps = ["//tensorflow/lite/c:c_api_internal"], ) cc_library( @@ -182,16 +182,16 @@ cc_library( ":simple_memory_arena", ":string", ":util", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/core/api", - "//tensorflow/contrib/lite/kernels:eigen_support", - "//tensorflow/contrib/lite/kernels:gemm_support", - "//tensorflow/contrib/lite/nnapi:nnapi_lib", - "//tensorflow/contrib/lite/profiling:profiler", - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels:eigen_support", + "//tensorflow/lite/kernels:gemm_support", + "//tensorflow/lite/nnapi:nnapi_lib", + "//tensorflow/lite/profiling:profiler", + "//tensorflow/lite/schema:schema_fbs", ] + select({ ":with_tflite_flex": [ - "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/lite/delegates/flex:delegate", ], "//conditions:default": [], }), @@ -214,7 +214,7 @@ cc_test( deps = [ ":framework", ":string_util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -227,13 +227,13 @@ cc_test( deps = [ ":framework", ":string_util", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/core/api", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", - "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:tensor_utils", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -247,7 +247,7 @@ cc_test( deps = [ ":framework", ":string_util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -259,7 +259,7 @@ cc_test( srcs = ["simple_memory_arena_test.cc"], deps = [ ":simple_memory_arena", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -279,10 +279,10 @@ cc_test( ], deps = [ ":framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/core/api", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -298,10 +298,10 @@ tf_cc_test( tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC. deps = [ ":framework", - "//tensorflow/contrib/lite/core/api", - "//tensorflow/contrib/lite/delegates/flex:delegate", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/flex:delegate", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -314,7 +314,7 @@ cc_test( tags = ["no_oss"], deps = [ ":framework", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -324,7 +324,7 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -335,7 +335,7 @@ cc_test( tags = ["no_oss"], deps = [ ":util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/lite/README.md similarity index 79% rename from tensorflow/contrib/lite/README.md rename to tensorflow/lite/README.md index a4b3d83efe0935..589d4f93481e50 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/lite/README.md @@ -5,4 +5,4 @@ devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. See the documentation: https://www.tensorflow.org/lite/ -Documentation edits can be made here: [tensorflow/contrib/lite/g3doc](./g3doc/) +Documentation edits can be made here: [tensorflow/lite/g3doc](./g3doc/) diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/lite/allocation.cc similarity index 94% rename from tensorflow/contrib/lite/allocation.cc rename to tensorflow/lite/allocation.cc index 21cb1832a7af49..f9a34322f0cbf5 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/lite/allocation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/lite/allocation.h" #include #include @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/lite/allocation.h similarity index 88% rename from tensorflow/contrib/lite/allocation.h rename to tensorflow/lite/allocation.h index 182bc0977f62f1..f25d7fa232a740 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/lite/allocation.h @@ -14,16 +14,16 @@ limitations under the License. ==============================================================================*/ // Main abstraction controlling the tflite interpreter. // See context.h for the API for defining operations (TfLiteRegistration). -#ifndef TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ -#define TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#ifndef TENSORFLOW_LITE_ALLOCATION_H_ +#define TENSORFLOW_LITE_ALLOCATION_H_ #include #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/simple_memory_arena.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/simple_memory_arena.h" +#include "tensorflow/lite/string.h" namespace tflite { @@ -94,4 +94,4 @@ class MemoryAllocation : public Allocation { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#endif // TENSORFLOW_LITE_ALLOCATION_H_ diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc similarity index 99% rename from tensorflow/contrib/lite/arena_planner.cc rename to tensorflow/lite/arena_planner.cc index 02442575b3aeed..8200b6adaa1c6e 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/arena_planner.h" +#include "tensorflow/lite/arena_planner.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/lite/arena_planner.h similarity index 92% rename from tensorflow/contrib/lite/arena_planner.h rename to tensorflow/lite/arena_planner.h index 382577045b6d54..beaadaf4eff758 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/lite/arena_planner.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ -#define TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#ifndef TENSORFLOW_LITE_ARENA_PLANNER_H_ +#define TENSORFLOW_LITE_ARENA_PLANNER_H_ #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/graph_info.h" -#include "tensorflow/contrib/lite/memory_planner.h" -#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/graph_info.h" +#include "tensorflow/lite/memory_planner.h" +#include "tensorflow/lite/simple_memory_arena.h" namespace tflite { @@ -124,4 +124,4 @@ class ArenaPlanner : public MemoryPlanner { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#endif // TENSORFLOW_LITE_ARENA_PLANNER_H_ diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc similarity index 99% rename from tensorflow/contrib/lite/arena_planner_test.cc rename to tensorflow/lite/arena_planner_test.cc index 7d7c41289cad95..479f25cafef5c4 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/arena_planner.h" +#include "tensorflow/lite/arena_planner.h" #include #include #include -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/testing/util.h" #include "tensorflow/core/platform/logging.h" namespace tflite { diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/lite/build_def.bzl similarity index 94% rename from tensorflow/contrib/lite/build_def.bzl rename to tensorflow/lite/build_def.bzl index b30c05a2e23d3e..3b0af52fb93690 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -38,7 +38,7 @@ def tflite_copts(): return copts -LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds" +LINKER_SCRIPT = "//tensorflow/lite/java/src/main/native:version_script.lds" def tflite_linkopts_unstripped(): """Defines linker flags to reduce size of TFLite binary. @@ -152,7 +152,7 @@ def tf_to_tflite(name, src, options, out): """ toco_cmdline = " ".join([ - "//tensorflow/contrib/lite/toco:toco", + "//tensorflow/lite/toco:toco", "--input_format=TENSORFLOW_GRAPHDEF", "--output_format=TFLITE", ("--input_file=$(location %s)" % src), @@ -163,7 +163,7 @@ def tf_to_tflite(name, src, options, out): srcs = [src], outs = [out], cmd = toco_cmdline, - tools = ["//tensorflow/contrib/lite/toco:toco"], + tools = ["//tensorflow/lite/toco:toco"], ) def tflite_to_json(name, src, out): @@ -176,7 +176,7 @@ def tflite_to_json(name, src, out): """ flatc = "@flatbuffers//:flatc" - schema = "//tensorflow/contrib/lite/schema:schema.fbs" + schema = "//tensorflow/lite/schema:schema.fbs" native.genrule( name = name, srcs = [schema, src], @@ -199,7 +199,7 @@ def json_to_tflite(name, src, out): """ flatc = "@flatbuffers//:flatc" - schema = "//tensorflow/contrib/lite/schema:schema_fbs" + schema = "//tensorflow/lite/schema:schema_fbs" native.genrule( name = name, srcs = [schema, src], @@ -307,11 +307,8 @@ def generated_test_models(): # bug or issue. def generated_test_models_failing(conversion_mode): if conversion_mode == "toco-flex": - # TODO(b/117328698): Fix and enable the known flex failures. return [ - "lstm", - "split", - "unpack", + "lstm", # TODO(b/117510976): Restore when lstm flex conversion works. ] return [] @@ -360,12 +357,12 @@ def gen_zip_test(name, test_name, conversion_mode, **kwargs): list above. **kwargs: tf_cc_test kwargs """ - toco = "//tensorflow/contrib/lite/toco:toco" + toco = "//tensorflow/lite/toco:toco" flags = "" if conversion_mode: # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050. # if conversion_mode == "pb2lite": - # toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite" + # toco = "//tensorflow/lite/experimental/pb2lite:pb2lite" flags = "--ignore_toco_errors --run_with_flex" gen_zipped_test_file( @@ -409,8 +406,8 @@ def gen_selected_ops(name, model): model: TFLite model to interpret. """ out = name + "_registration.cc" - tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" - tflite_path = "//tensorflow/contrib/lite" + tool = "//tensorflow/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/lite" native.genrule( name = name, srcs = [model], @@ -453,8 +450,8 @@ def gen_model_coverage_test(model_name, data, failure_type): "notap", ], deps = [ - "//tensorflow/contrib/lite/testing/model_coverage:model_coverage_lib", - "//tensorflow/contrib/lite/python:lite", + "//tensorflow/lite/testing/model_coverage:model_coverage_lib", + "//tensorflow/lite/python:lite", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/lite/builtin_op_data.h similarity index 77% rename from tensorflow/contrib/lite/builtin_op_data.h rename to tensorflow/lite/builtin_op_data.h index 30901bd0fae951..b9d4284513de94 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/lite/builtin_op_data.h @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ // Compatibility shim for new location of interface definitions. -#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#ifndef TENSORFLOW_LITE_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_LITE_BUILTIN_OP_DATA_H_ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" -#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#endif // TENSORFLOW_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h similarity index 96% rename from tensorflow/contrib/lite/builtin_ops.h rename to tensorflow/lite/builtin_ops.h index 1b115291b338ea..b8c05f57bb59b5 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#ifndef TENSORFLOW_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by // `schema/builtin_ops_header/generator.cc`. @@ -128,4 +128,4 @@ typedef enum { #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#endif // TENSORFLOW_LITE_BUILTIN_OPS_H_ diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/lite/c/BUILD similarity index 94% rename from tensorflow/contrib/lite/c/BUILD rename to tensorflow/lite/c/BUILD index 663eb63cad0da0..91c04a5f1fb5bb 100644 --- a/tensorflow/contrib/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -13,6 +13,7 @@ cc_library( ], visibility = [ "//tensorflow/contrib/lite:__subpackages__", + "//tensorflow/lite:__subpackages__", ], ) diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h similarity index 97% rename from tensorflow/contrib/lite/c/builtin_op_data.h rename to tensorflow/lite/c/builtin_op_data.h index c0513df9f64c94..855983d60dfd18 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ +#ifndef TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #ifdef __cplusplus extern "C" { @@ -332,4 +332,4 @@ typedef struct { } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ +#endif // TENSORFLOW_LITE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/lite/c/builtin_op_data_test.cc similarity index 98% rename from tensorflow/contrib/lite/c/builtin_op_data_test.cc rename to tensorflow/lite/c/builtin_op_data_test.cc index ba458b4252c53e..0e33dcd8c8447d 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data_test.cc +++ b/tensorflow/lite/c/builtin_op_data_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/lite/c/c_api_internal.c similarity index 98% rename from tensorflow/contrib/lite/c/c_api_internal.c rename to tensorflow/lite/c/c_api_internal.c index 0a88b5ef7bcbe0..b131f0677467b3 100644 --- a/tensorflow/contrib/lite/c/c_api_internal.c +++ b/tensorflow/lite/c/c_api_internal.c @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #ifndef TF_LITE_STATIC_MEMORY #include #include diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h similarity index 95% rename from tensorflow/contrib/lite/c/c_api_internal.h rename to tensorflow/lite/c/c_api_internal.h index f6aee92f2abeb4..cbc69f804be114 100644 --- a/tensorflow/contrib/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -26,8 +26,8 @@ limitations under the License. // TfLiteRegistration - the implementation of a conceptual operation. // // Some abstractions in this file are created and managed by Interpreter. -#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ -#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ +#ifndef TENSORFLOW_LITE_C_C_API_INTERNAL_H_ +#define TENSORFLOW_LITE_C_C_API_INTERNAL_H_ #include #include @@ -456,6 +456,22 @@ typedef struct _TfLiteRegistration { int version; } TfLiteRegistration; +// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the +// values should be 1, 2, 4, 8, ...etc. +typedef enum { + kTfLiteDelegateFlagsNone = 0, + // The flag is set if the delegate can handle dynamic sized tensors. + // For example, the output shape of a `Resize` op with non-constant shape + // can only be inferred when the op is invoked. + // In this case, the Delegate is responsible for calling + // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling + // `ResizeTensor` when invoking the op. + // + // If the delegate isn't capable to handle dynamic tensors, this flag need + // to be set to false. + kTfLiteDelegateFlagsAllowDynamicTensors = 1 +} TfLiteDelegateFlags; + // WARNING: This is an experimental interface that is subject to change. typedef struct _TfLiteDelegate { // Data that delegate needs to identify itself. This data is owned by the @@ -490,6 +506,9 @@ typedef struct _TfLiteDelegate { // This can be null if the delegate doesn't use its own buffer. void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, TfLiteBufferHandle* handle); + + // Bitmask flags. See the comments in `TfLiteDelegateFlags`. + int64_t flags; } TfLiteDelegate; // WARNING: This is an experimental interface that is subject to change. @@ -509,4 +528,4 @@ typedef struct { #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ +#endif // TENSORFLOW_LITE_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/c/c_api_internal_test.cc b/tensorflow/lite/c/c_api_internal_test.cc similarity index 78% rename from tensorflow/contrib/lite/c/c_api_internal_test.cc rename to tensorflow/lite/c/c_api_internal_test.cc index af398f32075b46..e21823c41f0b43 100644 --- a/tensorflow/contrib/lite/c/c_api_internal_test.cc +++ b/tensorflow/lite/c/c_api_internal_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #include namespace tflite { @@ -65,6 +65,21 @@ TEST(IntArray, TestIntArrayEqual) { TfLiteIntArrayFree(d); } +TEST(Types, TestTypeNames) { + auto type_name = [](TfLiteType t) { + return std::string(TfLiteTypeGetName(t)); + }; + EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE"); + EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32"); + EXPECT_EQ(type_name(kTfLiteInt16), "INT16"); + EXPECT_EQ(type_name(kTfLiteInt32), "INT32"); + EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8"); + EXPECT_EQ(type_name(kTfLiteInt64), "INT64"); + EXPECT_EQ(type_name(kTfLiteBool), "BOOL"); + EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64"); + EXPECT_EQ(type_name(kTfLiteString), "STRING"); +} + } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/lite/context.h similarity index 79% rename from tensorflow/contrib/lite/context.h rename to tensorflow/lite/context.h index b86c2819b821d7..3d3c8c08b24e69 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/lite/context.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Compatibility shim for moved header location. -#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ -#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#ifndef TENSORFLOW_LITE_CONTEXT_H_ +#define TENSORFLOW_LITE_CONTEXT_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" -#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#endif // TENSORFLOW_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/lite/context_util.h similarity index 89% rename from tensorflow/contrib/lite/context_util.h rename to tensorflow/lite/context_util.h index ccda4c7393dd16..68b91ea0b93e60 100644 --- a/tensorflow/contrib/lite/context_util.h +++ b/tensorflow/lite/context_util.h @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ // This provides a few C++ helpers that are useful for manipulating C structures // in C++. -#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#ifndef TENSORFLOW_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_LITE_CONTEXT_UTIL_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -45,4 +45,4 @@ class TfLiteIntArrayView { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#endif // TENSORFLOW_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD similarity index 80% rename from tensorflow/contrib/lite/core/api/BUILD rename to tensorflow/lite/core/api/BUILD index e4500534f348f1..6a43b0322d1704 100644 --- a/tensorflow/contrib/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -4,7 +4,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") cc_library( name = "api", @@ -20,8 +20,8 @@ cc_library( ], copts = tflite_copts(), deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/schema:schema_fbs", ], ) @@ -51,7 +51,7 @@ cc_test( srcs = ["flatbuffer_conversions_test.cc"], deps = [ ":api", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/lite/core/api/error_reporter.cc similarity index 95% rename from tensorflow/contrib/lite/core/api/error_reporter.cc rename to tensorflow/lite/core/api/error_reporter.cc index 423f83b1a9f4c9..7070eaa57c589a 100644 --- a/tensorflow/contrib/lite/core/api/error_reporter.cc +++ b/tensorflow/lite/core/api/error_reporter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/lite/core/api/error_reporter.h similarity index 88% rename from tensorflow/contrib/lite/core/api/error_reporter.h rename to tensorflow/lite/core/api/error_reporter.h index a2f780b003fc21..357722cc45911f 100644 --- a/tensorflow/contrib/lite/core/api/error_reporter.h +++ b/tensorflow/lite/core/api/error_reporter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ -#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ +#ifndef TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ #include @@ -42,4 +42,4 @@ class ErrorReporter { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ +#endif // TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/lite/core/api/error_reporter_test.cc similarity index 95% rename from tensorflow/contrib/lite/core/api/error_reporter_test.cc rename to tensorflow/lite/core/api/error_reporter_test.cc index 0463eee6be554e..4e44a6465d1ed9 100644 --- a/tensorflow/contrib/lite/core/api/error_reporter_test.cc +++ b/tensorflow/lite/core/api/error_reporter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc similarity index 99% rename from tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc rename to tensorflow/lite/core/api/flatbuffer_conversions.cc index 5afde439156e70..8cd3faabb72809 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h similarity index 85% rename from tensorflow/contrib/lite/core/api/flatbuffer_conversions.h rename to tensorflow/lite/core/api/flatbuffer_conversions.h index c770e627fd572d..0132a431c5daad 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ -#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ +#ifndef TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ +#define TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ // These functions transform codes and data structures that are defined in the // flatbuffer serialization format into in-memory values that are used by the // runtime API and interpreter. -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/core/api/op_resolver.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -65,4 +65,4 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ +#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc similarity index 97% rename from tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc rename to tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 8ae94e1d330c19..4d1d1b21fda106 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/lite/core/api/op_resolver.cc similarity index 97% rename from tensorflow/contrib/lite/core/api/op_resolver.cc rename to tensorflow/lite/core/api/op_resolver.cc index 55ee92484305c3..94d76889d07903 100644 --- a/tensorflow/contrib/lite/core/api/op_resolver.cc +++ b/tensorflow/lite/core/api/op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/api/op_resolver.h" namespace tflite { diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h similarity index 84% rename from tensorflow/contrib/lite/core/api/op_resolver.h rename to tensorflow/lite/core/api/op_resolver.h index 5f5e6b27363b52..c8c7479f334c7c 100644 --- a/tensorflow/contrib/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_ -#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_ +#ifndef TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -44,4 +44,4 @@ TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode, } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_ +#endif // TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/lite/core/api/op_resolver_test.cc similarity index 99% rename from tensorflow/contrib/lite/core/api/op_resolver_test.cc rename to tensorflow/lite/core/api/op_resolver_test.cc index 167463110ed8ec..cd8d0929b64495 100644 --- a/tensorflow/contrib/lite/core/api/op_resolver_test.cc +++ b/tensorflow/lite/core/api/op_resolver_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD similarity index 81% rename from tensorflow/contrib/lite/delegates/flex/BUILD rename to tensorflow/lite/delegates/flex/BUILD index cd6bf4d9a55f22..222a043a88e880 100644 --- a/tensorflow/contrib/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -16,8 +16,9 @@ cc_library( deps = [ ":util", "//tensorflow/c:c_api_internal", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite:string", + "//tensorflow/lite:string_util", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", @@ -35,9 +36,10 @@ tf_cc_test( srcs = ["buffer_map_test.cc"], deps = [ ":buffer_map", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -51,7 +53,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":delegate_data", ":delegate_only_runtime", + "//tensorflow/lite/c:c_api_internal", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib", @@ -79,9 +83,9 @@ cc_library( ":delegate_data", ":kernel", ":util", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/contrib/lite:util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:util", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", @@ -101,7 +105,7 @@ tf_cc_test( deps = [ ":delegate", ":test_util", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -130,10 +134,10 @@ tf_cc_test( srcs = ["delegate_data_test.cc"], deps = [ ":delegate_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:util", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:framework", + "//tensorflow/lite:util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -146,10 +150,10 @@ cc_library( ":delegate_data", ":util", "@flatbuffers", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/contrib/lite:string", - "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:string", + "//tensorflow/lite/kernels:kernel_util", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", @@ -195,8 +199,8 @@ cc_library( hdrs = ["test_util.h"], deps = [ "//tensorflow/c:c_api_internal", - "//tensorflow/contrib/lite:string", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:string", + "//tensorflow/lite/kernels:test_util", "@com_google_absl//absl/memory", "@flatbuffers", ], @@ -208,8 +212,8 @@ cc_library( hdrs = ["util.h"], deps = [ "//tensorflow/c:c_api_internal", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite:kernel_api", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", @@ -227,8 +231,8 @@ tf_cc_test( srcs = ["util_test.cc"], deps = [ ":util", - "//tensorflow/contrib/lite:string", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:string", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/delegates/flex/buffer_map.cc b/tensorflow/lite/delegates/flex/buffer_map.cc similarity index 58% rename from tensorflow/contrib/lite/delegates/flex/buffer_map.cc rename to tensorflow/lite/delegates/flex/buffer_map.cc index 63e39196d96a17..9a6c5e74a7b8d7 100644 --- a/tensorflow/contrib/lite/delegates/flex/buffer_map.cc +++ b/tensorflow/lite/delegates/flex/buffer_map.cc @@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" +#include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/contrib/lite/delegates/flex/util.h" +#include "tensorflow/lite/delegates/flex/util.h" +#include "tensorflow/lite/string.h" +#include "tensorflow/lite/string_util.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" @@ -23,54 +25,101 @@ namespace tflite { namespace flex { namespace { // A tensor buffer that is allocated, deallocated and populated by TF Lite. -class TfLiteTensorBuffer : public tensorflow::TensorBuffer { +class BaseTfLiteTensorBuffer : public tensorflow::TensorBuffer { + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + } + + // Prevents input forwarding from mutating this buffer. + bool OwnsMemory() const override { return false; } + + protected: + void LogAllocation() { + if (tensorflow::LogMemory::IsEnabled() && data() != nullptr) { + tensorflow::LogMemory::RecordRawAllocation( + "TfLiteTensorBuffer_New", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, size(), + data(), tensorflow::cpu_allocator()); + } + } + void LogDeallocation() { + if (tensorflow::LogMemory::IsEnabled() && data() != nullptr) { + tensorflow::LogMemory::RecordRawDeallocation( + "TfLiteTensorBuffer_Delete", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data(), + tensorflow::cpu_allocator(), false); + } + } +}; + +// A tensor buffer for most data types. Numeric types have exactly the same +// representation in TFLITE and TF, so we just need use memcpy(). +class TfLiteTensorBuffer : public BaseTfLiteTensorBuffer { public: explicit TfLiteTensorBuffer(const TfLiteTensor* tensor) { - len_ = tensor->bytes; // TODO(ahentz): if we can guarantee that TF Lite allocated tensors with // the same alignment as TensorFlow (EIGEN_MAX_ALIGN_BYTES), then we can // potentially eliminate the copy below. + len_ = tensor->bytes; data_ = tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len_); - if (data_ != nullptr) { - if (tensorflow::LogMemory::IsEnabled()) { - tensorflow::LogMemory::RecordRawAllocation( - "TfLiteTensorBuffer_New", - tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len_, - data_, tensorflow::cpu_allocator()); - } + + LogAllocation(); + + if (data_) { std::memcpy(data_, tensor->data.raw, tensor->bytes); } } ~TfLiteTensorBuffer() override { - if (tensorflow::LogMemory::IsEnabled() && data_ != nullptr) { - tensorflow::LogMemory::RecordRawDeallocation( - "TfLiteTensorBuffer_Delete", - tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data_, - tensorflow::cpu_allocator(), false); - } + LogDeallocation(); tensorflow::cpu_allocator()->DeallocateRaw(data_); } void* data() const override { return data_; } size_t size() const override { return len_; } - TensorBuffer* root_buffer() override { return this; } - void FillAllocationDescription( - tensorflow::AllocationDescription* proto) const override { - tensorflow::int64 rb = size(); - proto->set_requested_bytes(rb); - proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + private: + void* data_; + size_t len_; +}; + +// A string buffer. TFLITE string tensor format is different than +// TF's so we need perform the conversion here. +class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer { + public: + explicit StringTfLiteTensorBuffer(const TfLiteTensor* tensor) { + num_strings_ = GetStringCount(tensor->data.raw); + data_ = tensorflow::cpu_allocator()->Allocate(num_strings_); + + LogAllocation(); + + if (data_) { + string* p = data_; + for (size_t i = 0; i < num_strings_; ++p, ++i) { + auto ref = GetString(tensor->data.raw, i); + p->assign(ref.str, ref.len); + } + } } - // Prevents input forwarding from mutating this buffer. - bool OwnsMemory() const override { return false; } + ~StringTfLiteTensorBuffer() override { + LogDeallocation(); + tensorflow::cpu_allocator()->Deallocate(data_, num_strings_); + } + + void* data() const override { return data_; } + size_t size() const override { return num_strings_ * sizeof(string); } private: - void* data_; - size_t len_; + string* data_; + int num_strings_; }; + } // namespace BufferMap::BufferMap() {} @@ -81,6 +130,10 @@ bool BufferMap::HasTensor(int tensor_index) const { return id_to_tensor_.count(tensor_index) != 0; } +bool BufferMap::IsTensorFlowTensor(int tensor_index) const { + return HasTensor(tensor_index) && owned_by_tf_.count(tensor_index) > 0; +} + tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const { return id_to_tensor_.at(tensor_index); } @@ -93,18 +146,25 @@ void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { } // TODO(ahentz): we assume this is a new tensor and allocate a new buffer // for it. This is not always the best approach. For example, this might - // be a reallocation after resizing tensors. In that case we would be + // be a reallocation after resizing tensors. In that case it would be // preferable to somehow reuse the buffer. - auto* buf = new TfLiteTensorBuffer(tensor); + BaseTfLiteTensorBuffer* buf; + if (tensor->type == kTfLiteString) { + buf = new StringTfLiteTensorBuffer(tensor); + } else { + buf = new TfLiteTensorBuffer(tensor); + } tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor( GetTensorFlowDataType(tensor->type), shape, buf); buf->Unref(); - SetFromTensorFlow(tensor_index, std::move(t)); + id_to_tensor_[tensor_index] = std::move(t); + owned_by_tf_.erase(tensor_index); } void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) { id_to_tensor_[tensor_index] = std::move(tensor); + owned_by_tf_.insert(tensor_index); } } // namespace flex diff --git a/tensorflow/contrib/lite/delegates/flex/buffer_map.h b/tensorflow/lite/delegates/flex/buffer_map.h similarity index 61% rename from tensorflow/contrib/lite/delegates/flex/buffer_map.h rename to tensorflow/lite/delegates/flex/buffer_map.h index 4ce886568a5577..b73ed88d3789d5 100644 --- a/tensorflow/contrib/lite/delegates/flex/buffer_map.h +++ b/tensorflow/lite/delegates/flex/buffer_map.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" namespace tflite { @@ -38,12 +38,17 @@ class BufferMap { // tensorflow::Tensor. bool HasTensor(int tensor_index) const; + // Returns true if the given 'tensor_index' has a corresponding + // tensorflow::Tensor *and* the content is owned by TensorFlow (that is, the + // mapping was added by SetFromTensorFlow()). + bool IsTensorFlowTensor(int tensor_index) const; + // Returns the tensorflow::Tensor associated with the given 'tensor_index'. // Precondition: HasTensor() is true. tensorflow::Tensor GetTensor(int tensor_index) const; // Associates the given tensorflow::Tensor with the given 'tensor_index'. - // Note that tensorflow Tensors share data buffers, so this method is only a + // Note that TensorFlow Tensors share data buffers, so this method is only a // shallow copy. void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor); @@ -52,10 +57,20 @@ class BufferMap { void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor); private: + // Mapping from TL Lite tensor ID to TensorFlow's Tensor. All tensors that + // are inputs or outputs of a subgraph will be added here, irrespective of + // whether their data are managed by TF Lite or TensorFlow. std::map id_to_tensor_; + // A list of tensors that are completely managed by TensorFlow. Most of the + // time, TF Lite will populate tensors that are inputs to subgraphs, while + // TensorFlow will populate output tensors. Occasionally, however, an input + // tensor is coming from a previous subgraph and could have been populated by + // TensorFlow. This set keeps track of all input or output tensors that have + // been populated by tensorflow. + std::set owned_by_tf_; }; } // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc b/tensorflow/lite/delegates/flex/buffer_map_test.cc similarity index 71% rename from tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc rename to tensorflow/lite/delegates/flex/buffer_map_test.cc index bb80e25e8076bb..9e8472f1e7d2c3 100644 --- a/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc +++ b/tensorflow/lite/delegates/flex/buffer_map_test.cc @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" +#include "tensorflow/lite/delegates/flex/buffer_map.h" #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/testing/util.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace flex { @@ -42,11 +43,35 @@ UniqueTfLiteTensor MakeLiteTensor(const std::vector& shape, tensor->type = typeToTfLiteType(); tensor->dims = ConvertVectorToTfLiteIntArray(shape); tensor->data.raw = nullptr; + tensor->is_variable = false; TfLiteTensorRealloc(data.size() * sizeof(T), tensor.get()); memcpy(tensor->data.raw, data.data(), data.size() * sizeof(T)); return tensor; } +template <> +UniqueTfLiteTensor MakeLiteTensor(const std::vector& shape, + const std::vector& data) { + auto tensor = UniqueTfLiteTensor(new TfLiteTensor, [](TfLiteTensor* t) { + TfLiteTensorDataFree(t); + TfLiteIntArrayFree(t->dims); + delete t; + }); + tensor->allocation_type = kTfLiteDynamic; + tensor->type = typeToTfLiteType(); + tensor->dims = ConvertVectorToTfLiteIntArray(shape); + tensor->data.raw = nullptr; + tensor->is_variable = false; + TfLiteTensorRealloc(data.size() * sizeof(string), tensor.get()); + + DynamicBuffer b; + for (const string& s : data) { + b.AddString(s.data(), s.size()); + } + b.WriteToTensor(tensor.get(), ConvertVectorToTfLiteIntArray(shape)); + return tensor; +} + template tensorflow::Tensor MakeTensor(const std::vector& shape, const std::vector& data) { @@ -93,6 +118,24 @@ TEST(BufferMapTest, SetFromTfLite) { ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); } +TEST(BufferMapTest, SetFromTfLiteString) { + BufferMap buffer_map; + + UniqueTfLiteTensor t = + MakeLiteTensor({1, 2, 1, 3}, {"", "", "", "str1", "", ""}); + buffer_map.SetFromTfLite(0, t.get()); + ASSERT_TRUE(buffer_map.HasTensor(0)); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre("", "", "", "str1", "", "")); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_STRING); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + TEST(BufferMapTest, SetFromTfLiteTwice) { UniqueTfLiteTensor t1 = MakeLiteTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); @@ -107,6 +150,20 @@ TEST(BufferMapTest, SetFromTfLiteTwice) { ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); } +TEST(BufferMapTest, SetFromTfLiteStringTwice) { + UniqueTfLiteTensor t1 = + MakeLiteTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {"", "", "", "s3", "", "", "s1", "s2"}); + + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t1.get()); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre("", "", "", "s3", "", "", "s1", "s2")); +} + TEST(BufferMapTest, SetFromTensorFlow) { tensorflow::Tensor t1 = MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); @@ -146,6 +203,7 @@ TEST(BufferMapTest, TfLiteOverwritesTensorFlow) { buffer_map.SetFromTensorFlow(0, t1); buffer_map.SetFromTfLite(0, t2.get()); + EXPECT_FALSE(buffer_map.IsTensorFlowTensor(0)); EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); } @@ -159,6 +217,7 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) { buffer_map.SetFromTfLite(0, t2.get()); buffer_map.SetFromTensorFlow(0, t1); + EXPECT_TRUE(buffer_map.IsTensorFlowTensor(0)); EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), ElementsAre(0, 0, 0, 0.123f, 0, 0)); } diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc similarity index 90% rename from tensorflow/contrib/lite/delegates/flex/delegate.cc rename to tensorflow/lite/delegates/flex/delegate.cc index c72b0cf5138389..e7433f0c47874e 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate.cc +++ b/tensorflow/lite/delegates/flex/delegate.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" +#include "tensorflow/lite/delegates/flex/delegate.h" #include -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" -#include "tensorflow/contrib/lite/delegates/flex/kernel.h" -#include "tensorflow/contrib/lite/delegates/flex/util.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/flex/buffer_map.h" +#include "tensorflow/lite/delegates/flex/kernel.h" +#include "tensorflow/lite/delegates/flex/util.h" +#include "tensorflow/lite/util.h" #include "tensorflow/core/lib/core/status.h" namespace tflite { @@ -109,7 +109,8 @@ FlexDelegate::FlexDelegate(std::unique_ptr delegate_data) /*nullptr,*/ &flex::delegate::Prepare, /*CopyFromBufferHandle=*/&flex::delegate::CopyFromBufferHandle, /*CopyToBufferHandle=*/nullptr, - /*FreeBufferHandle=*/nullptr}, + /*FreeBufferHandle=*/nullptr, + /*flags=*/kTfLiteDelegateFlagsAllowDynamicTensors}, delegate_data_(std::move(delegate_data)) {} FlexDelegate::~FlexDelegate() {} diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.h b/tensorflow/lite/delegates/flex/delegate.h similarity index 85% rename from tensorflow/contrib/lite/delegates/flex/delegate.h rename to tensorflow/lite/delegates/flex/delegate.h index 1017780dc75de1..018ff3e0b0e1fe 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate.h +++ b/tensorflow/lite/delegates/flex/delegate.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/delegates/flex/delegate_data.h" namespace tflite { @@ -56,4 +56,4 @@ class FlexDelegate : public TfLiteDelegate { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc similarity index 96% rename from tensorflow/contrib/lite/delegates/flex/delegate_data.cc rename to tensorflow/lite/delegates/flex/delegate_data.cc index 8f985f770cfba9..b62479a448073d 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate_data.cc +++ b/tensorflow/lite/delegates/flex/delegate_data.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_data.h b/tensorflow/lite/delegates/flex/delegate_data.h similarity index 87% rename from tensorflow/contrib/lite/delegates/flex/delegate_data.h rename to tensorflow/lite/delegates/flex/delegate_data.h index 8d75f0b0efe758..a88cc98d03cd40 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate_data.h +++ b/tensorflow/lite/delegates/flex/delegate_data.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ -#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" +#include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/core/common_runtime/eager/context.h" namespace tflite { @@ -49,4 +49,4 @@ class DelegateData { } // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc b/tensorflow/lite/delegates/flex/delegate_data_test.cc similarity index 90% rename from tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc rename to tensorflow/lite/delegates/flex/delegate_data_test.cc index 30b10f435a2378..cd274e7cb1ccb5 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc +++ b/tensorflow/lite/delegates/flex/delegate_data_test.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/lite/delegates/flex/delegate_data.h" #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace flex { diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_test.cc b/tensorflow/lite/delegates/flex/delegate_test.cc similarity index 96% rename from tensorflow/contrib/lite/delegates/flex/delegate_test.cc rename to tensorflow/lite/delegates/flex/delegate_test.cc index 1813952cef99ef..e13029d9a514e7 100644 --- a/tensorflow/contrib/lite/delegates/flex/delegate_test.cc +++ b/tensorflow/lite/delegates/flex/delegate_test.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" +#include "tensorflow/lite/delegates/flex/delegate.h" #include #include -#include "tensorflow/contrib/lite/delegates/flex/test_util.h" +#include "tensorflow/lite/delegates/flex/test_util.h" namespace tflite { namespace flex { @@ -40,8 +40,7 @@ class DelegateTest : public testing::FlexModelTest { } void ConfigureDelegate() { - ASSERT_EQ(interpreter_->ModifyGraphWithDelegate( - delegate_.get(), /*allow_dynamic_tensors=*/true), + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); } diff --git a/tensorflow/contrib/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc similarity index 93% rename from tensorflow/contrib/lite/delegates/flex/kernel.cc rename to tensorflow/lite/delegates/flex/kernel.cc index e4f1aea990da97..c4fe142dff1051 100644 --- a/tensorflow/contrib/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/kernel.h" +#include "tensorflow/lite/delegates/flex/kernel.h" #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" -#include "tensorflow/contrib/lite/delegates/flex/util.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/flex/delegate_data.h" +#include "tensorflow/lite/delegates/flex/util.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" @@ -251,7 +251,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { for (auto tensor_index : op_data->subgraph_inputs) { TfLiteTensor* tensor = &context->tensors[tensor_index]; if (!IsConstantTensor(tensor)) { - buffer_map->SetFromTfLite(tensor_index, tensor); + // If this tensor is part of an earlier TF subgraph we should not add it + // to the BufferMap again, because TF already knows about it and its + // contents are kept automatically up-to-date. + if (!buffer_map->IsTensorFlowTensor(tensor_index)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } } } diff --git a/tensorflow/contrib/lite/delegates/flex/kernel.h b/tensorflow/lite/delegates/flex/kernel.h similarity index 83% rename from tensorflow/contrib/lite/delegates/flex/kernel.h rename to tensorflow/lite/delegates/flex/kernel.h index ac9313a37bd5a3..ca74c28570f6aa 100644 --- a/tensorflow/contrib/lite/delegates/flex/kernel.h +++ b/tensorflow/lite/delegates/flex/kernel.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { namespace flex { @@ -31,4 +31,4 @@ TfLiteRegistration GetKernel(); } // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_KERNEL_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/kernel_test.cc b/tensorflow/lite/delegates/flex/kernel_test.cc similarity index 68% rename from tensorflow/contrib/lite/delegates/flex/kernel_test.cc rename to tensorflow/lite/delegates/flex/kernel_test.cc index 94a6f8b61ad281..4742c24bfc907e 100644 --- a/tensorflow/contrib/lite/delegates/flex/kernel_test.cc +++ b/tensorflow/lite/delegates/flex/kernel_test.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/kernel.h" +#include "tensorflow/lite/delegates/flex/kernel.h" #include #include -#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" -#include "tensorflow/contrib/lite/delegates/flex/test_util.h" +#include "tensorflow/lite/delegates/flex/delegate_data.h" +#include "tensorflow/lite/delegates/flex/test_util.h" namespace tflite { namespace flex { @@ -53,6 +53,7 @@ class KernelTest : public testing::FlexModelTest { template void ConfigureDelegate(T prepare_function) { delegate_.data_ = delegate_data_.get(); + delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; delegate_.FreeBufferHandle = nullptr; delegate_.Prepare = prepare_function; delegate_.CopyFromBufferHandle = [](TfLiteContext* context, @@ -66,8 +67,7 @@ class KernelTest : public testing::FlexModelTest { memcpy(data, values.data(), values.size()); return kTfLiteOk; }; - CHECK(interpreter_->ModifyGraphWithDelegate( - &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk); + CHECK(interpreter_->ModifyGraphWithDelegate(&delegate_) == kTfLiteOk); } private: @@ -100,6 +100,17 @@ TEST_F(KernelTest, FullGraph) { ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + + // Try again with different inputs + SetShape(0, {2, 3, 1}); + SetValues(0, {2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f}); + SetShape(3, {2, 3, 1}); + SetValues(3, {2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(3, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(24.0f, 32.0f, 48.0f)); } TEST_F(KernelTest, BadTensorFlowOp) { @@ -194,29 +205,69 @@ TEST_F(KernelTest, MixedGraph) { ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); } +// We will build a complex graph where most of the ops are TF ops, but one +// of them, right in the middle is handle natively by TF Lite. This results +// in two flex subgraphs to handle the TF ops, and some of the tensors +// connect those two subgraphs directly. TEST_F(KernelTest, SplitGraph) { - AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + std::vector a = {3.0f, 1.0f, 0.5f, -1.0f, 4.0f, -1.0f, -2.0f, 5.0f}; + std::vector b = {0.0f, 1.0f, 1.5f, 3.0f}; - AddTfOp(testing::kUnpack, {0}, {1, 2}); - AddTfOp(testing::kAdd, {1, 2}, {3}); - AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTensors(18, {0, 1}, {17}, kTfLiteFloat32, {3}); + + // Split the first input. Each branch below uses one half of it. + AddTfOp(testing::kUnpack, {0}, {2, 10}); - AddTfLiteMulOp({4, 5}, {6}); + // The left branch: l = (a0 + b0) * (a2 + b2) + (a1 + b1) * (a3 + b3) = 10 + AddTfOp(testing::kAdd, {1, 2}, {3}); // => 3, 2, 2, 2 + AddTfOp(testing::kUnpack, {3}, {4, 5}); // => 3, 2 --- 2, 2 + AddTfLiteMulOp({4, 5}, {6}); // => 6, 4 + AddTfOp(testing::kUnpack, {6}, {7, 8}); // => 6 -- 4 + AddTfOp(testing::kAdd, {7, 8}, {9}); // => 10 - AddTfOp(testing::kUnpack, {6}, {7, 8}); - AddTfOp(testing::kAdd, {7, 8}, {9}); + // The right branch: r = (a4 + a6) + (a5 + a7) = 6 + AddTfOp(testing::kUnpack, {10}, {11, 12}); // => 4, -1 --- -2, 5 + AddTfOp(testing::kAdd, {11, 12}, {13}); // => 2, 4 + AddTfOp(testing::kUnpack, {13}, {14, 15}); // => 2 --- 4 + AddTfOp(testing::kAdd, {14, 15}, {16}); // => 6 + + // The two branches added together: + AddTfOp(testing::kAdd, {9, 16}, {17}); // => 16 ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { - return GenericPrepare(context, delegate, {0, 1, 2, 4, 5}); + // All ops by #3 are TF ops, handled by the delegate. However, because #4 + // depends on the non-TF op, two subgraphs are necessary: + // TF subgraph 1: 0, 1, 2, 6, 7, 8, 9 + // TF Lite Op: 3 + // TF subgraph 2: 4, 5, 10 + return GenericPrepare(context, delegate, {0, 1, 2, 4, 5, 6, 7, 8, 9, 10}); }); SetShape(0, {2, 2, 2, 1}); - SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + SetValues(0, a); + SetShape(1, {2, 2, 1}); + SetValues(1, b); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(17), ElementsAre(1)); + ASSERT_THAT(GetValues(17), ElementsAre(16.0f)); + + // Same as above but with slightly different output. + // We still expect the result to be l + r where + // l = (a0 + b0) * (a2 + b2) + (a1 + b1) * (a3 + b3) + // r = (a4 + a6) + (a5 + a7) + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {4.0f, 1.0f, 1.5f, -2.0f, 2.0f, 0.0f, -2.0f, 3.0f}); + SetShape(1, {2, 2, 1}); + SetValues(1, {0.0f, 2.0f, 1.5f, 3.0f}); + // So l = (4 + 0) * (1.5 + 1.5) + (1 + 2) * (-2 + 3) = 12 + 3 = 15 + // r = (2 - 2) + (0 + 3) = 3 ASSERT_TRUE(Invoke()); - ASSERT_THAT(GetShape(9), ElementsAre(1)); - ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); + ASSERT_THAT(GetShape(17), ElementsAre(1)); + ASSERT_THAT(GetValues(17), ElementsAre(18.0f)); } } // namespace diff --git a/tensorflow/contrib/lite/delegates/flex/test_util.cc b/tensorflow/lite/delegates/flex/test_util.cc similarity index 98% rename from tensorflow/contrib/lite/delegates/flex/test_util.cc rename to tensorflow/lite/delegates/flex/test_util.cc index 69c336a01a5741..08feb349e6dbf1 100644 --- a/tensorflow/contrib/lite/delegates/flex/test_util.cc +++ b/tensorflow/lite/delegates/flex/test_util.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/test_util.h" +#include "tensorflow/lite/delegates/flex/test_util.h" #include "absl/memory/memory.h" #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace flex { diff --git a/tensorflow/contrib/lite/delegates/flex/test_util.h b/tensorflow/lite/delegates/flex/test_util.h similarity index 94% rename from tensorflow/contrib/lite/delegates/flex/test_util.h rename to tensorflow/lite/delegates/flex/test_util.h index a8c81b90a3b8dc..4d3f5ad0968ad3 100644 --- a/tensorflow/contrib/lite/delegates/flex/test_util.h +++ b/tensorflow/lite/delegates/flex/test_util.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_TEST_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_TEST_UTIL_H_ #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/test_util.h" namespace tflite { namespace flex { @@ -116,4 +116,4 @@ class FlexModelTest : public ::testing::Test { } // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc similarity index 98% rename from tensorflow/contrib/lite/delegates/flex/util.cc rename to tensorflow/lite/delegates/flex/util.cc index 829bc388bf4f61..c786ffa1a2150b 100644 --- a/tensorflow/contrib/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/util.h" +#include "tensorflow/lite/delegates/flex/util.h" namespace tflite { namespace flex { diff --git a/tensorflow/contrib/lite/delegates/flex/util.h b/tensorflow/lite/delegates/flex/util.h similarity index 88% rename from tensorflow/contrib/lite/delegates/flex/util.h rename to tensorflow/lite/delegates/flex/util.h index 7f910e7316e673..8aaa73d1b3e370 100644 --- a/tensorflow/contrib/lite/delegates/flex/util.h +++ b/tensorflow/lite/delegates/flex/util.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_ #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -44,4 +44,4 @@ TfLiteType GetTensorFlowLiteType(TF_DataType); } // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc similarity index 96% rename from tensorflow/contrib/lite/delegates/flex/util_test.cc rename to tensorflow/lite/delegates/flex/util_test.cc index 5f049e7b0a0c1f..87104751b81b6a 100644 --- a/tensorflow/contrib/lite/delegates/flex/util_test.cc +++ b/tensorflow/lite/delegates/flex/util_test.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/flex/util.h" +#include "tensorflow/lite/delegates/flex/util.h" #include #include #include -#include "tensorflow/contrib/lite/string.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/string.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace flex { diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD similarity index 58% rename from tensorflow/contrib/lite/delegates/nnapi/BUILD rename to tensorflow/lite/delegates/nnapi/BUILD index 4e7b2948fb920c..c24f0f71ac4edd 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/BUILD +++ b/tensorflow/lite/delegates/nnapi/BUILD @@ -11,11 +11,11 @@ cc_library( srcs = ["nnapi_delegate.cc"], hdrs = ["nnapi_delegate.h"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/lite:framework", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/nnapi:nnapi_lib", ], ) @@ -29,9 +29,9 @@ tf_cc_test( ], deps = [ ":nnapi_delegate", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc similarity index 99% rename from tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc rename to tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index d85e576284fac8..74aec27f82ac6f 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -17,14 +17,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/nnapi/NeuralNetworksShim.h" #ifdef __ANDROID__ #include @@ -1141,6 +1141,7 @@ class NNAPIDelegateKernel { TfLiteDelegate* NnApiDelegate() { static TfLiteDelegate delegate = { .data_ = nullptr, + .flags = kTfLiteDelegateFlagsAllowDynamicTensors, .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // Do not check nodes_ if NN API is unavailable. diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h similarity index 80% rename from tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h rename to tensorflow/lite/delegates/nnapi/nnapi_delegate.h index 4852b7697432c3..099fb724292d79 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#define TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -28,4 +28,4 @@ namespace tflite { TfLiteDelegate* NnApiDelegate(); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#endif // TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc similarity index 99% rename from tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc rename to tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index 9626c54c7473bf..84a0a6a1d1cd0c 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { @@ -31,7 +31,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel { public: SingleOpModelWithNNAPI() { this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false); + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); }); } }; diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/lite/error_reporter.h similarity index 72% rename from tensorflow/contrib/lite/error_reporter.h rename to tensorflow/lite/error_reporter.h index 5c20eedc255ca6..38518d63321eda 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/lite/error_reporter.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Compatibility shim for moved header location. -#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ -#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#ifndef TENSORFLOW_LITE_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_ERROR_REPORTER_H_ -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/stderr_reporter.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/stderr_reporter.h" -#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#endif // TENSORFLOW_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/lite/examples/android/BUILD similarity index 77% rename from tensorflow/contrib/lite/examples/android/BUILD rename to tensorflow/lite/examples/android/BUILD index d180cb478566a9..761a60314e8fb6 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/lite/examples/android/BUILD @@ -33,14 +33,14 @@ android_binary( # Remove undesired models (and corresponding Activities in source) # to reduce APK size. assets = [ - "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", + "//tensorflow/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", - "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", + "//tensorflow/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", "@tflite_mobilenet_ssd_quant//:detect.tflite", - "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt", - "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt", + "//tensorflow/lite/examples/android/app/src/main/assets:box_priors.txt", + "//tensorflow/lite/examples/android/app/src/main/assets:coco_labels_list.txt", ], assets_dir = "", custom_package = "org.tensorflow.lite.demo", @@ -56,6 +56,6 @@ android_binary( ], deps = [ ":tensorflow_native_libs", - "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/lite/java:tensorflowlite", ], ) diff --git a/tensorflow/contrib/lite/examples/android/android.iml b/tensorflow/lite/examples/android/android.iml similarity index 100% rename from tensorflow/contrib/lite/examples/android/android.iml rename to tensorflow/lite/examples/android/android.iml diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/lite/examples/android/app/README.md similarity index 94% rename from tensorflow/contrib/lite/examples/android/app/README.md rename to tensorflow/lite/examples/android/app/README.md index 7347147f997540..e2b1b2691bb926 100644 --- a/tensorflow/contrib/lite/examples/android/app/README.md +++ b/tensorflow/lite/examples/android/app/README.md @@ -43,12 +43,12 @@ for our external and internal code to merge. ```shell bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ - //tensorflow/contrib/lite/examples/android:tflite_demo + //tensorflow/lite/examples/android:tflite_demo ``` 3. Install the demo on a [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): ```shell - adb install bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk + adb install bazel-bin/tensorflow/lite/examples/android/tflite_demo.apk ``` diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/lite/examples/android/app/build.gradle similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/build.gradle rename to tensorflow/lite/examples/android/app/build.gradle diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/lite/examples/android/app/download-models.gradle similarity index 96% rename from tensorflow/contrib/lite/examples/android/app/download-models.gradle rename to tensorflow/lite/examples/android/app/download-models.gradle index c100e37c16f38a..d2f03db5f6373b 100644 --- a/tensorflow/contrib/lite/examples/android/app/download-models.gradle +++ b/tensorflow/lite/examples/android/app/download-models.gradle @@ -14,7 +14,7 @@ def models = ['conv_actions_tflite.zip', 'mobilenet_ssd_tflite_v1.zip', 'mobilenet_v1_224_android_quant_2017_11_08.zip', 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip'] -// LINT.ThenChange(//tensorflow/contrib/lite/examples/android/BUILD) +// LINT.ThenChange(//tensorflow/lite/examples/android/BUILD) // Root URL for model archives def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml b/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml rename to tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD b/tensorflow/lite/examples/android/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD rename to tensorflow/lite/examples/android/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt b/tensorflow/lite/examples/android/app/src/main/assets/box_priors.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt rename to tensorflow/lite/examples/android/app/src/main/assets/box_priors.txt diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt b/tensorflow/lite/examples/android/app/src/main/assets/coco_labels_list.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt rename to tensorflow/lite/examples/android/app/src/main/assets/coco_labels_list.txt diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt b/tensorflow/lite/examples/android/app/src/main/assets/conv_actions_labels.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt rename to tensorflow/lite/examples/android/app/src/main/assets/conv_actions_labels.txt diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt rename to tensorflow/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt b/tensorflow/lite/examples/android/app/src/main/assets/pets_labels_list.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt rename to tensorflow/lite/examples/android/app/src/main/assets/pets_labels_list.txt diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java rename to tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml b/tensorflow/lite/examples/android/app/src/main/res/animator/color_animation.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml rename to tensorflow/lite/examples/android/app/src/main/res/animator/color_animation.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png rename to tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml b/tensorflow/lite/examples/android/app/src/main/res/drawable/border.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml rename to tensorflow/lite/examples/android/app/src/main/res/drawable/border.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml rename to tensorflow/lite/examples/android/app/src/main/res/layout/activity_camera.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/activity_speech.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml rename to tensorflow/lite/examples/android/app/src/main/res/layout/activity_speech.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml rename to tensorflow/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml rename to tensorflow/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml rename to tensorflow/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/list_text_item.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml rename to tensorflow/lite/examples/android/app/src/main/res/layout/list_text_item.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values-v11/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-v11/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-v11/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values-v14/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-v14/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml b/tensorflow/lite/examples/android/app/src/main/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-v21/base-colors.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml b/tensorflow/lite/examples/android/app/src/main/res/values/attrs.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/attrs.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml b/tensorflow/lite/examples/android/app/src/main/res/values/base-strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/base-strings.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml b/tensorflow/lite/examples/android/app/src/main/res/values/colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/colors.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml b/tensorflow/lite/examples/android/app/src/main/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/strings.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml b/tensorflow/lite/examples/android/app/src/main/res/values/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml b/tensorflow/lite/examples/android/app/src/main/res/values/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml rename to tensorflow/lite/examples/android/app/src/main/res/values/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/lite/examples/android/build.gradle similarity index 100% rename from tensorflow/contrib/lite/examples/android/build.gradle rename to tensorflow/lite/examples/android/build.gradle diff --git a/tensorflow/contrib/lite/examples/android/settings.gradle b/tensorflow/lite/examples/android/settings.gradle similarity index 100% rename from tensorflow/contrib/lite/examples/android/settings.gradle rename to tensorflow/lite/examples/android/settings.gradle diff --git a/tensorflow/contrib/lite/examples/ios/camera/.gitignore b/tensorflow/lite/examples/ios/camera/.gitignore similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/.gitignore rename to tensorflow/lite/examples/ios/camera/.gitignore diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h b/tensorflow/lite/examples/ios/camera/CameraExampleAppDelegate.h similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h rename to tensorflow/lite/examples/ios/camera/CameraExampleAppDelegate.h diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m b/tensorflow/lite/examples/ios/camera/CameraExampleAppDelegate.m similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m rename to tensorflow/lite/examples/ios/camera/CameraExampleAppDelegate.m diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.h similarity index 94% rename from tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h rename to tensorflow/lite/examples/ios/camera/CameraExampleViewController.h index fb5800e86d365b..6bc94e950220b9 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h +++ b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.h @@ -17,8 +17,8 @@ #include -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" @interface CameraExampleViewController : UIViewController { diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm similarity index 99% rename from tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm rename to tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm index 996cff26162021..1e6725592b0c6b 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm @@ -23,10 +23,10 @@ #include #include -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/op_resolver.h" #define LOG(x) std::cerr diff --git a/tensorflow/contrib/lite/examples/ios/camera/Info.plist b/tensorflow/lite/examples/ios/camera/Info.plist similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/Info.plist rename to tensorflow/lite/examples/ios/camera/Info.plist diff --git a/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard b/tensorflow/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard rename to tensorflow/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/lite/examples/ios/camera/Podfile similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/Podfile rename to tensorflow/lite/examples/ios/camera/Podfile diff --git a/tensorflow/contrib/lite/examples/ios/camera/README.md b/tensorflow/lite/examples/ios/camera/README.md similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/README.md rename to tensorflow/lite/examples/ios/camera/README.md diff --git a/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore b/tensorflow/lite/examples/ios/camera/data/.gitignore similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/data/.gitignore rename to tensorflow/lite/examples/ios/camera/data/.gitignore diff --git a/tensorflow/contrib/lite/examples/ios/camera/main.mm b/tensorflow/lite/examples/ios/camera/main.mm similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/main.mm rename to tensorflow/lite/examples/ios/camera/main.mm diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj rename to tensorflow/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example_with_select_tf_ops.xcodeproj/project.pbxproj b/tensorflow/lite/examples/ios/camera/tflite_camera_example_with_select_tf_ops.xcodeproj/project.pbxproj similarity index 100% rename from tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example_with_select_tf_ops.xcodeproj/project.pbxproj rename to tensorflow/lite/examples/ios/camera/tflite_camera_example_with_select_tf_ops.xcodeproj/project.pbxproj diff --git a/tensorflow/contrib/lite/examples/ios/download_models.sh b/tensorflow/lite/examples/ios/download_models.sh similarity index 100% rename from tensorflow/contrib/lite/examples/ios/download_models.sh rename to tensorflow/lite/examples/ios/download_models.sh diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/lite/examples/ios/simple/AppDelegate.h similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h rename to tensorflow/lite/examples/ios/simple/AppDelegate.h diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/lite/examples/ios/simple/AppDelegate.mm similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm rename to tensorflow/lite/examples/ios/simple/AppDelegate.mm diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/lite/examples/ios/simple/Podfile similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/Podfile rename to tensorflow/lite/examples/ios/simple/Podfile diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist b/tensorflow/lite/examples/ios/simple/RunModel-Info.plist similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist rename to tensorflow/lite/examples/ios/simple/RunModel-Info.plist diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/lite/examples/ios/simple/RunModelViewController.h similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h rename to tensorflow/lite/examples/ios/simple/RunModelViewController.h diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/lite/examples/ios/simple/RunModelViewController.mm similarity index 97% rename from tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm rename to tensorflow/lite/examples/ios/simple/RunModelViewController.mm index 650c73f7322c31..e5764944f66507 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/lite/examples/ios/simple/RunModelViewController.mm @@ -22,10 +22,10 @@ #include #include -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/op_resolver.h" #include "ios_image_load.h" diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib b/tensorflow/lite/examples/ios/simple/RunModelViewController.xib similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib rename to tensorflow/lite/examples/ios/simple/RunModelViewController.xib diff --git a/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/lite/examples/ios/simple/data/grace_hopper.jpg similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg rename to tensorflow/lite/examples/ios/simple/data/grace_hopper.jpg diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/lite/examples/ios/simple/ios_image_load.h similarity index 78% rename from tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h rename to tensorflow/lite/examples/ios/simple/ios_image_load.h index 96d28109375a71..74c6cf3c7b1ac6 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h +++ b/tensorflow/lite/examples/ios/simple/ios_image_load.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_ -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_ +#ifndef TENSORFLOW_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_ +#define TENSORFLOW_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_ #include std::vector LoadImageFromFile(const char* file_name, int* out_width, int* out_height, int* out_channels); -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_ +#endif // TENSORFLOW_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/lite/examples/ios/simple/ios_image_load.mm similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm rename to tensorflow/lite/examples/ios/simple/ios_image_load.mm diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/lite/examples/ios/simple/main.mm similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/main.mm rename to tensorflow/lite/examples/ios/simple/main.mm diff --git a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj b/tensorflow/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj similarity index 100% rename from tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj rename to tensorflow/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/lite/examples/label_image/BUILD similarity index 67% rename from tensorflow/contrib/lite/examples/label_image/BUILD rename to tensorflow/lite/examples/label_image/BUILD index fc55a78019b4a1..de1bfd70532565 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/lite/examples/label_image/BUILD @@ -6,7 +6,7 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/lite:build_def.bzl", "tflite_linkopts") exports_files(glob([ "testdata/*.bmp", @@ -28,9 +28,9 @@ tf_cc_binary( }), deps = [ ":bitmap_helpers", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", ], ) @@ -43,13 +43,13 @@ cc_library( "label_image.h", ], deps = [ - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite:string", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite:string", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/schema:schema_fbs", ], ) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/lite/examples/label_image/bitmap_helpers.cc similarity index 98% rename from tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc rename to tensorflow/lite/examples/label_image/bitmap_helpers.cc index 2735d1f5ea4e2a..0adad68ddca892 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/lite/examples/label_image/bitmap_helpers.cc @@ -21,7 +21,7 @@ limitations under the License. #include // NOLINT(build/include_order) -#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/lite/examples/label_image/bitmap_helpers.h" #define LOG(x) std::cerr diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/lite/examples/label_image/bitmap_helpers.h similarity index 79% rename from tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h rename to tensorflow/lite/examples/label_image/bitmap_helpers.h index 7881ee80cad432..05209963a16c12 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/lite/examples/label_image/bitmap_helpers.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ +#ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ +#define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ -#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h" -#include "tensorflow/contrib/lite/examples/label_image/label_image.h" +#include "tensorflow/lite/examples/label_image/bitmap_helpers_impl.h" +#include "tensorflow/lite/examples/label_image/label_image.h" namespace tflite { namespace label_image { @@ -39,4 +39,4 @@ template void resize(float*, unsigned char*, int, int, int, int, int, } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ +#endif // TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/lite/examples/label_image/bitmap_helpers_impl.h similarity index 84% rename from tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h rename to tensorflow/lite/examples/label_image/bitmap_helpers_impl.h index 21ad39a6bf75e5..b581d807734213 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/lite/examples/label_image/bitmap_helpers_impl.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ +#ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ +#define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ -#include "tensorflow/contrib/lite/examples/label_image/label_image.h" +#include "tensorflow/lite/examples/label_image/label_image.h" -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/version.h" namespace tflite { namespace label_image { @@ -93,4 +93,4 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ +#endif // TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n.h b/tensorflow/lite/examples/label_image/get_top_n.h similarity index 82% rename from tensorflow/contrib/lite/examples/label_image/get_top_n.h rename to tensorflow/lite/examples/label_image/get_top_n.h index adef434c00a680..47fea2f775826d 100644 --- a/tensorflow/contrib/lite/examples/label_image/get_top_n.h +++ b/tensorflow/lite/examples/label_image/get_top_n.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_ -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_ +#ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_ +#define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_ -#include "tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h" +#include "tensorflow/lite/examples/label_image/get_top_n_impl.h" namespace tflite { namespace label_image { @@ -35,4 +35,4 @@ template void get_top_n(float*, int, size_t, float, } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_ +#endif // TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h b/tensorflow/lite/examples/label_image/get_top_n_impl.h similarity index 90% rename from tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h rename to tensorflow/lite/examples/label_image/get_top_n_impl.h index 708cf2f2b1cab9..563ac09114c234 100644 --- a/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h +++ b/tensorflow/lite/examples/label_image/get_top_n_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_ +#ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_ +#define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_ #include #include @@ -67,4 +67,4 @@ void get_top_n(T* prediction, int prediction_size, size_t num_results, } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_ +#endif // TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc similarity index 97% rename from tensorflow/contrib/lite/examples/label_image/label_image.cc rename to tensorflow/lite/examples/label_image/label_image.cc index 7c6f523041ad5a..b8dc2840dfb49f 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -32,13 +32,13 @@ limitations under the License. #include // NOLINT(build/include_order) #include // NOLINT(build/include_order) -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/optional_debug_tools.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/optional_debug_tools.h" +#include "tensorflow/lite/string_util.h" -#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" -#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" +#include "tensorflow/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/lite/examples/label_image/get_top_n.h" #define LOG(x) std::cerr diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/lite/examples/label_image/label_image.h similarity index 82% rename from tensorflow/contrib/lite/examples/label_image/label_image.h rename to tensorflow/lite/examples/label_image/label_image.h index f0be881b58573a..88b047fecc4b3e 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/lite/examples/label_image/label_image.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ +#ifndef TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ +#define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace label_image { @@ -40,4 +40,4 @@ struct Settings { } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ +#endif // TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/lite/examples/label_image/label_image.md similarity index 84% rename from tensorflow/contrib/lite/examples/label_image/label_image.md rename to tensorflow/lite/examples/label_image/label_image.md index 9ce32cf101897f..fd9f49918b4494 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.md +++ b/tensorflow/lite/examples/label_image/label_image.md @@ -10,12 +10,12 @@ To build it for android ARMv8: --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=arm64-v8a \ - //tensorflow/contrib/lite/examples/label_image:label_image + //tensorflow/lite/examples/label_image:label_image ``` or ``` > bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \ - //tensorflow/contrib/lite/examples/label_image:label_image + //tensorflow/lite/examples/label_image:label_image ``` To build it for android arm-v7a: @@ -24,17 +24,17 @@ To build it for android arm-v7a: --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=armeabi-v7a \ - //tensorflow/contrib/lite/examples/label_image:label_image + //tensorflow/lite/examples/label_image:label_image ``` or ``` > bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \ - //tensorflow/contrib/lite/examples/label_image:label_image + //tensorflow/lite/examples/label_image:label_image ``` Build it for desktop machines (tested on Ubuntu and OS X) ``` -> bazel build --config opt --cxxopt=-std=c++11 //tensorflow/contrib/lite/examples/label_image:label_image +> bazel build --config opt --cxxopt=-std=c++11 //tensorflow/lite/examples/label_image:label_image ``` To run it. Prepare `./mobilenet_quant_v1_224.tflite`, `./grace_hopper.bmp`, and `./labels.txt`. diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/lite/examples/label_image/label_image_test.cc similarity index 86% rename from tensorflow/contrib/lite/examples/label_image/label_image_test.cc rename to tensorflow/lite/examples/label_image/label_image_test.cc index de7de21f7741d3..6b4ec2a9374ca5 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/lite/examples/label_image/label_image_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" -#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" -#include "tensorflow/contrib/lite/examples/label_image/label_image.h" +#include "tensorflow/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/lite/examples/label_image/get_top_n.h" +#include "tensorflow/lite/examples/label_image/label_image.h" using ::testing::ElementsAreArray; @@ -27,7 +27,7 @@ namespace label_image { TEST(LabelImageTest, GraceHopper) { std::string lena_file = - "tensorflow/contrib/lite/examples/label_image/testdata/" + "tensorflow/lite/examples/label_image/testdata/" "grace_hopper.bmp"; int height, width, channels; Settings s; diff --git a/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp b/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp similarity index 100% rename from tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp rename to tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/lite/examples/minimal/BUILD similarity index 76% rename from tensorflow/contrib/lite/examples/minimal/BUILD rename to tensorflow/lite/examples/minimal/BUILD index b403628d6c457c..cdd67af1e93661 100644 --- a/tensorflow/contrib/lite/examples/minimal/BUILD +++ b/tensorflow/lite/examples/minimal/BUILD @@ -6,7 +6,7 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/lite:build_def.bzl", "tflite_linkopts") tf_cc_binary( name = "minimal", @@ -21,7 +21,7 @@ tf_cc_binary( "//conditions:default": [], }), deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", ], ) diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/lite/examples/minimal/minimal.cc similarity index 92% rename from tensorflow/contrib/lite/examples/minimal/minimal.cc rename to tensorflow/lite/examples/minimal/minimal.cc index 8b65cde7b79fde..46f8b09df6cee1 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/lite/examples/minimal/minimal.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/optional_debug_tools.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/optional_debug_tools.h" // This is an example that is minimal to read a model // from disk and perform inference. There is no data being loaded diff --git a/tensorflow/contrib/lite/examples/python/BUILD b/tensorflow/lite/examples/python/BUILD similarity index 83% rename from tensorflow/contrib/lite/examples/python/BUILD rename to tensorflow/lite/examples/python/BUILD index d337c3ddc43a23..a606d1aa563261 100644 --- a/tensorflow/contrib/lite/examples/python/BUILD +++ b/tensorflow/lite/examples/python/BUILD @@ -8,6 +8,6 @@ py_binary( main = "label_image.py", srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/lite/python:lite", + "//tensorflow/lite/python:lite", ], ) diff --git a/tensorflow/contrib/lite/examples/python/label_image.md b/tensorflow/lite/examples/python/label_image.md similarity index 81% rename from tensorflow/contrib/lite/examples/python/label_image.md rename to tensorflow/lite/examples/python/label_image.md index e81192a96c142f..b4ec42f52594cf 100644 --- a/tensorflow/contrib/lite/examples/python/label_image.md +++ b/tensorflow/lite/examples/python/label_image.md @@ -6,7 +6,7 @@ The example input image and labels file are from TensorFlow repo and MobileNet V1 model files. ``` -curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp > /tmp/grace_hopper.bmp +curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp > /tmp/grace_hopper.bmp curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C /tmp mobilenet_v1_1.0_224/labels.txt mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/ @@ -17,7 +17,7 @@ Run ``` curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz | tar xzv -C /tmp -bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image +bazel run --config opt //tensorflow/lite/examples/python:label_image ``` We can get results like @@ -34,7 +34,7 @@ Run ``` curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp -bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image \ +bazel run --config opt //tensorflow/lite/examples/python:label_image \ -- --model_file /tmp/mobilenet_v1_1.0_224.tflite ``` diff --git a/tensorflow/contrib/lite/examples/python/label_image.py b/tensorflow/lite/examples/python/label_image.py similarity index 97% rename from tensorflow/contrib/lite/examples/python/label_image.py rename to tensorflow/lite/examples/python/label_image.py index 282118a1d2b43a..0bc15d36a8ac2e 100644 --- a/tensorflow/contrib/lite/examples/python/label_image.py +++ b/tensorflow/lite/examples/python/label_image.py @@ -23,7 +23,7 @@ from PIL import Image -from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper +from tensorflow.lite.python import interpreter as interpreter_wrapper def load_labels(filename): my_labels = [] diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/lite/experimental/c/BUILD similarity index 64% rename from tensorflow/contrib/lite/experimental/c/BUILD rename to tensorflow/lite/experimental/c/BUILD index 52e71619def71a..5dd62194deac2e 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/lite/experimental/c/BUILD @@ -3,14 +3,14 @@ package(default_visibility = ["//visibility:private"]) package_group( name = "experimental", packages = [ - "//tensorflow/contrib/lite/experimental/...", + "//tensorflow/lite/experimental/...", ], ) licenses(["notice"]) # Apache 2.0 load( - "//tensorflow/contrib/lite:build_def.bzl", + "//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object", "tflite_copts", "tflite_jni_binary", @@ -21,14 +21,14 @@ tflite_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow/contrib/lite/experimental/c:exported_symbols.lds)", + "$(location //tensorflow/lite/experimental/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflowlite_c.so", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/contrib/lite/experimental/c:version_script.lds)", + "$(location //tensorflow/lite/experimental/c:version_script.lds)", ], }), deps = [ @@ -45,11 +45,11 @@ cc_library( hdrs = ["c_api_internal.h"], copts = tflite_copts(), visibility = [ - "//tensorflow/contrib/lite/experimental/c:__subpackages__", + "//tensorflow/lite/experimental/c:__subpackages__", ], deps = [ - "//tensorflow/contrib/lite:context", - "//tensorflow/contrib/lite:framework", + "//tensorflow/lite:context", + "//tensorflow/lite:framework", ], ) @@ -63,10 +63,10 @@ cc_library( ], deps = [ ":c_api_internal", - "//tensorflow/contrib/lite:context", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:context", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/kernels:builtin_ops", ], ) @@ -78,7 +78,7 @@ cc_library( deps = [ ":c_api", ":c_api_internal", - "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/lite:kernel_api", ], ) @@ -86,12 +86,15 @@ cc_test( name = "c_api_test", size = "small", srcs = ["c_api_test.cc"], - data = ["//tensorflow/contrib/lite:testdata/add.bin"], + data = [ + "//tensorflow/lite:testdata/add.bin", + "//tensorflow/lite:testdata/add_quantized.bin", + ], deps = [ ":c_api", - "//tensorflow/contrib/lite:context", - "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:context", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -100,12 +103,12 @@ cc_test( name = "c_api_experimental_test", size = "small", srcs = ["c_api_experimental_test.cc"], - data = ["//tensorflow/contrib/lite:testdata/add.bin"], + data = ["//tensorflow/lite:testdata/add.bin"], deps = [ ":c_api", ":c_api_experimental", - "//tensorflow/contrib/lite:kernel_api", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/lite/experimental/c/c_api.cc similarity index 92% rename from tensorflow/contrib/lite/experimental/c/c_api.cc rename to tensorflow/lite/experimental/c/c_api.cc index 9c29f9d8b9ddfd..9caacfeb3614a9 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/lite/experimental/c/c_api.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include "tensorflow/lite/experimental/c/c_api.h" #include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/experimental/c/c_api_internal.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" #ifdef __cplusplus extern "C" { @@ -181,6 +181,10 @@ void* TFL_TensorData(const TFL_Tensor* tensor) { const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; } +TFL_QuantizationParams TFL_TensorQuantizationParams(const TFL_Tensor* tensor) { + return tensor->params; +} + TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data, size_t input_data_size) { if (tensor->bytes != input_data_size) { @@ -199,7 +203,7 @@ TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data, return kTfLiteOk; } -// LINT.ThenChange(//tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs) +// LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs) #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/lite/experimental/c/c_api.h similarity index 93% rename from tensorflow/contrib/lite/experimental/c/c_api.h rename to tensorflow/lite/experimental/c/c_api.h index f52ab8f9ed65aa..49089011d1376b 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.h +++ b/tensorflow/lite/experimental/c/c_api.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ #include #include @@ -21,7 +21,7 @@ limitations under the License. // Eventually the various C APIs defined in context.h will be migrated into // the appropriate /c/c_api*.h header. For now, we pull in existing definitions // for convenience. -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/lite/context.h" // -------------------------------------------------------------------------- // Experimental C API for TensorFlowLite. @@ -53,6 +53,7 @@ limitations under the License. extern "C" { #endif // __cplusplus +typedef TfLiteQuantizationParams TFL_QuantizationParams; typedef TfLiteRegistration TFL_Registration; typedef TfLiteStatus TFL_Status; typedef TfLiteTensor TFL_Tensor; @@ -200,6 +201,13 @@ TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor); // Returns the (null-terminated) name of the tensor. TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor); +// Returns the parameters for asymmetric quantization. The quantization +// parameters are only valid when the tensor type is `kTfLiteUInt8` and the +// `scale != 0`. Quantized values can be converted back to float using: +// real_value = scale * (quantized_value - zero_point); +TFL_CAPI_EXPORT extern TFL_QuantizationParams TFL_TensorQuantizationParams( + const TFL_Tensor* tensor); + // Copies from the provided input buffer into the tensor's buffer. // REQUIRES: input_data_size == TFL_TensorByteSize(tensor) TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer( @@ -215,4 +223,4 @@ TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyToBuffer( } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/lite/experimental/c/c_api_experimental.cc similarity index 92% rename from tensorflow/contrib/lite/experimental/c/c_api_experimental.cc rename to tensorflow/lite/experimental/c/c_api_experimental.cc index 29f8701f53407d..a246ed99cd3736 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc +++ b/tensorflow/lite/experimental/c/c_api_experimental.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" +#include "tensorflow/lite/experimental/c/c_api_experimental.h" -#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" +#include "tensorflow/lite/experimental/c/c_api_internal.h" #ifdef __cplusplus extern "C" { diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/lite/experimental/c/c_api_experimental.h similarity index 87% rename from tensorflow/contrib/lite/experimental/c/c_api_experimental.h rename to tensorflow/lite/experimental/c/c_api_experimental.h index fca5d92f77caff..e4cd084520e52c 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h +++ b/tensorflow/lite/experimental/c/c_api_experimental.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ -#include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/experimental/c/c_api.h" #ifdef __cplusplus extern "C" { @@ -54,4 +54,4 @@ void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options, } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/lite/experimental/c/c_api_experimental_test.cc similarity index 86% rename from tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc rename to tensorflow/lite/experimental/c/c_api_experimental_test.cc index 1b1bedb7547063..e79c7204c6e7b9 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc +++ b/tensorflow/lite/experimental/c/c_api_experimental_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" +#include "tensorflow/lite/experimental/c/c_api_experimental.h" #include -#include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/experimental/c/c_api.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/experimental/c/c_api.h" +#include "tensorflow/lite/testing/util.h" namespace { @@ -34,7 +34,7 @@ TfLiteRegistration* GetDummyRegistration() { TEST(CApiExperimentalSimple, Smoke) { TFL_Model* model = TFL_NewModelFromFile( - "tensorflow/contrib/lite/testdata/add.bin"); + "tensorflow/lite/testdata/add.bin"); ASSERT_NE(model, nullptr); TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/lite/experimental/c/c_api_internal.h similarity index 82% rename from tensorflow/contrib/lite/experimental/c/c_api_internal.h rename to tensorflow/lite/experimental/c/c_api_internal.h index da3af3cad4c548..8a2987c8f1c88f 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h +++ b/tensorflow/lite/experimental/c/c_api_internal.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ -#include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include "tensorflow/lite/experimental/c/c_api.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/op_resolver.h" // Internal structures used by the C API. These are likely to change and should // not be depended on. @@ -58,4 +58,4 @@ struct TFL_Interpreter { std::unique_ptr impl; }; -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/lite/experimental/c/c_api_test.cc similarity index 61% rename from tensorflow/contrib/lite/experimental/c/c_api_test.cc rename to tensorflow/lite/experimental/c/c_api_test.cc index 48a3714ec345a6..5fb14f342cba89 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc +++ b/tensorflow/lite/experimental/c/c_api_test.cc @@ -15,17 +15,17 @@ limitations under the License. #include -#include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include "tensorflow/lite/experimental/c/c_api.h" #include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/testing/util.h" namespace { TEST(CApiSimple, Smoke) { TFL_Model* model = TFL_NewModelFromFile( - "tensorflow/contrib/lite/testdata/add.bin"); + "tensorflow/lite/testdata/add.bin"); ASSERT_NE(model, nullptr); TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); @@ -58,6 +58,11 @@ TEST(CApiSimple, Smoke) { EXPECT_NE(TFL_TensorData(input_tensor), nullptr); EXPECT_STREQ(TFL_TensorName(input_tensor), "input"); + TFL_QuantizationParams input_params = + TFL_TensorQuantizationParams(input_tensor); + EXPECT_EQ(input_params.scale, 0.f); + EXPECT_EQ(input_params.zero_point, 0); + std::array input = {1.f, 3.f}; ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(), input.size() * sizeof(float)), @@ -75,6 +80,11 @@ TEST(CApiSimple, Smoke) { EXPECT_NE(TFL_TensorData(output_tensor), nullptr); EXPECT_STREQ(TFL_TensorName(output_tensor), "output"); + TFL_QuantizationParams output_params = + TFL_TensorQuantizationParams(output_tensor); + EXPECT_EQ(output_params.scale, 0.f); + EXPECT_EQ(output_params.zero_point, 0); + std::array output; ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(), output.size() * sizeof(float)), @@ -85,9 +95,69 @@ TEST(CApiSimple, Smoke) { TFL_DeleteInterpreter(interpreter); } +TEST(CApiSimple, QuantizationParams) { + TFL_Model* model = TFL_NewModelFromFile( + "tensorflow/lite/testdata/add_quantized.bin"); + ASSERT_NE(model, nullptr); + + TFL_Interpreter* interpreter = TFL_NewInterpreter(model, nullptr); + ASSERT_NE(interpreter, nullptr); + + TFL_DeleteModel(model); + + const std::array input_dims = {2}; + ASSERT_EQ(TFL_InterpreterResizeInputTensor(interpreter, 0, input_dims.data(), + input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + + TFL_Tensor* input_tensor = TFL_InterpreterGetInputTensor(interpreter, 0); + ASSERT_NE(input_tensor, nullptr); + EXPECT_EQ(TFL_TensorType(input_tensor), kTfLiteUInt8); + EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1); + EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2); + + TFL_QuantizationParams input_params = + TFL_TensorQuantizationParams(input_tensor); + EXPECT_EQ(input_params.scale, 0.003922f); + EXPECT_EQ(input_params.zero_point, 0); + + const std::array input = {1, 3}; + ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(), + input.size() * sizeof(uint8_t)), + kTfLiteOk); + + ASSERT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk); + + const TFL_Tensor* output_tensor = + TFL_InterpreterGetOutputTensor(interpreter, 0); + ASSERT_NE(output_tensor, nullptr); + + TFL_QuantizationParams output_params = + TFL_TensorQuantizationParams(output_tensor); + EXPECT_EQ(output_params.scale, 0.003922f); + EXPECT_EQ(output_params.zero_point, 0); + + std::array output; + ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(), + output.size() * sizeof(uint8_t)), + kTfLiteOk); + EXPECT_EQ(output[0], 3); + EXPECT_EQ(output[1], 9); + + const float dequantizedOutput0 = + output_params.scale * (output[0] - output_params.zero_point); + const float dequantizedOutput1 = + output_params.scale * (output[1] - output_params.zero_point); + EXPECT_EQ(dequantizedOutput0, 0.011766f); + EXPECT_EQ(dequantizedOutput1, 0.035298f); + + TFL_DeleteInterpreter(interpreter); +} + TEST(CApiSimple, ErrorReporter) { TFL_Model* model = TFL_NewModelFromFile( - "tensorflow/contrib/lite/testdata/add.bin"); + "tensorflow/lite/testdata/add.bin"); TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); // Install a custom error reporter into the interpreter by way of options. diff --git a/tensorflow/contrib/lite/experimental/c/exported_symbols.lds b/tensorflow/lite/experimental/c/exported_symbols.lds similarity index 100% rename from tensorflow/contrib/lite/experimental/c/exported_symbols.lds rename to tensorflow/lite/experimental/c/exported_symbols.lds diff --git a/tensorflow/contrib/lite/experimental/c/version_script.lds b/tensorflow/lite/experimental/c/version_script.lds similarity index 100% rename from tensorflow/contrib/lite/experimental/c/version_script.lds rename to tensorflow/lite/experimental/c/version_script.lds diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/BUILD b/tensorflow/lite/experimental/examples/lstm/BUILD similarity index 91% rename from tensorflow/contrib/lite/experimental/examples/lstm/BUILD rename to tensorflow/lite/experimental/examples/lstm/BUILD index 3c1fe5f8becc32..7a475a24d36b6a 100644 --- a/tensorflow/contrib/lite/experimental/examples/lstm/BUILD +++ b/tensorflow/lite/experimental/examples/lstm/BUILD @@ -11,7 +11,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow:tensorflow_py", - "//tensorflow/contrib/lite/python:lite", + "//tensorflow/lite/python:lite", "//tensorflow/python:framework", "@six_archive//:six", ], @@ -31,8 +31,8 @@ py_test( deps = [ ":tflite_lstm", "//tensorflow:tensorflow_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/examples/tutorials/mnist:input_data", + "//tensorflow/lite/python:lite", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python/tools:optimize_for_inference", diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py similarity index 99% rename from tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py rename to tensorflow/lite/experimental/examples/lstm/tflite_lstm.py index 2357743266f708..2fe8ebf9e99f8b 100644 --- a/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py +++ b/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py @@ -21,7 +21,7 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.contrib.lite.python import lite +from tensorflow.lite.python import lite from tensorflow.python.keras import activations from tensorflow.python.keras import initializers from tensorflow.python.layers import base as base_layer diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py similarity index 94% rename from tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py rename to tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py index 2ca977518cb11d..eeb48d123113c5 100644 --- a/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py +++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py @@ -19,8 +19,9 @@ import numpy as np import tensorflow as tf -from tensorflow.contrib.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell +from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.tools import optimize_for_inference_lib @@ -50,17 +51,17 @@ def setUp(self): # Batch size self.batch_size = 16 # Lstm Units. - self.num_units = 64 + self.num_units = 16 def buildLstmLayer(self): return tf.nn.rnn_cell.MultiRNNCell([ TFLiteLSTMCell( self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), - TFLiteLSTMCell(self.num_units, num_proj=64, forget_bias=0, name="rnn2"), + TFLiteLSTMCell(self.num_units, num_proj=8, forget_bias=0, name="rnn2"), TFLiteLSTMCell( self.num_units // 2, use_peepholes=True, - num_proj=64, + num_proj=8, forget_bias=0, name="rnn3"), TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4") @@ -150,15 +151,15 @@ def tfliteInvoke(self, graph, test_inputs, outputs): tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input}) with tf.Session() as sess: curr = sess.graph_def - curr = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=curr) + curr = convert_op_hints_to_stubs(graph_def=curr) curr = optimize_for_inference_lib.optimize_for_inference( curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], [tf.float32.as_datatype_enum]) - tflite = tf.contrib.lite.toco_convert( + tflite = tf.lite.toco_convert( curr, [tflite_input], [outputs], allow_custom_ops=False) - interpreter = tf.contrib.lite.Interpreter(model_content=tflite) + interpreter = tf.lite.Interpreter(model_content=tflite) try: interpreter.allocate_tensors() @@ -189,7 +190,7 @@ def testStaticRnnMultiRnnCell(self): x, output_class, new_sess) result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) - self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) def testDynamicRnnMultiRnnCell(self): sess = tf.Session(config=CONFIG) @@ -219,7 +220,7 @@ def testDynamicRnnMultiRnnCell(self): x, output_class, new_sess) result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) - self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2)) if __name__ == "__main__": diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md similarity index 88% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md index f480c49cd050de..8c85ebfb63885f 100644 --- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md +++ b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md @@ -10,7 +10,7 @@ that this has only been tested on Linux; the syntax may differ on Mac/Windows): ```sh bazel build -c opt --cxxopt=--std=c++11 \ - //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so + //tensorflow/lite/experimental/c:libtensorflowlite_c.so ``` and for Android: @@ -20,7 +20,7 @@ bazel build -c opt --cxxopt=--std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=armeabi-v7a \ - //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so + //tensorflow/lite/experimental/c:libtensorflowlite_c.so ``` If you encounter issues with native plugin discovery on Mac ("Darwin") diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json b/tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json similarity index 100% rename from tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json rename to tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/lite/experimental/kernels/BUILD similarity index 51% rename from tensorflow/contrib/lite/experimental/kernels/BUILD rename to tensorflow/lite/experimental/kernels/BUILD index 4786cc62f93dc0..dd314545cb6488 100644 --- a/tensorflow/contrib/lite/experimental/kernels/BUILD +++ b/tensorflow/lite/experimental/kernels/BUILD @@ -4,7 +4,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") # ctc support classes imported directly from TensorFlow. @@ -19,7 +19,7 @@ cc_library( ], deps = [ ":top_n", - "//tensorflow/contrib/lite/kernels/internal:types", + "//tensorflow/lite/kernels/internal:types", "//third_party/eigen3", ], ) @@ -31,7 +31,7 @@ cc_library( "top_n.h", ], deps = [ - "//tensorflow/contrib/lite/kernels/internal:types", + "//tensorflow/lite/kernels/internal:types", ], ) @@ -50,21 +50,21 @@ cc_library( }), deps = [ ":ctc_utils", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:gemm_support", - "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/contrib/lite/kernels:op_macros", - "//tensorflow/contrib/lite/kernels/internal:kernel_utils", - "//tensorflow/contrib/lite/kernels/internal:optimized", - "//tensorflow/contrib/lite/kernels/internal:optimized_base", - "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:reference_base", - "//tensorflow/contrib/lite/kernels/internal:tensor", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:gemm_support", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/kernels/internal:kernel_utils", + "//tensorflow/lite/kernels/internal:optimized", + "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:tensor_utils", "@flatbuffers", ], ) @@ -76,9 +76,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":experimental_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h b/tensorflow/lite/experimental/kernels/ctc_beam_entry.h similarity index 94% rename from tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h rename to tensorflow/lite/experimental/kernels/ctc_beam_entry.h index a60ff2a1c53f1b..70fbefa2ba52c6 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h +++ b/tensorflow/lite/experimental/kernels/ctc_beam_entry.h @@ -15,8 +15,8 @@ limitations under the License. // Copied from tensorflow/core/util/ctc/ctc_beam_entry.h // TODO(b/111524997): Remove this file. -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ #include #include @@ -24,7 +24,7 @@ limitations under the License. #include #include "third_party/eigen3/Eigen/Core" -#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" +#include "tensorflow/lite/experimental/kernels/ctc_loss_util.h" namespace tflite { namespace experimental { @@ -147,4 +147,4 @@ class BeamComparer { } // namespace experimental } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h b/tensorflow/lite/experimental/kernels/ctc_beam_scorer.h similarity index 91% rename from tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h rename to tensorflow/lite/experimental/kernels/ctc_beam_scorer.h index ec60e26257b0f4..202b2af28ee14f 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h +++ b/tensorflow/lite/experimental/kernels/ctc_beam_scorer.h @@ -23,10 +23,10 @@ limitations under the License. // Copied from tensorflow/core/util/ctc/ctc_beam_scorer.h // TODO(b/111524997): Remove this file. -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ -#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" +#include "tensorflow/lite/experimental/kernels/ctc_beam_entry.h" namespace tflite { namespace experimental { @@ -76,4 +76,4 @@ class BaseBeamScorer { } // namespace experimental } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/lite/experimental/kernels/ctc_beam_search.h similarity index 96% rename from tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h rename to tensorflow/lite/experimental/kernels/ctc_beam_search.h index 7c5099235a4e32..1cc3ab7605ec3b 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h +++ b/tensorflow/lite/experimental/kernels/ctc_beam_search.h @@ -15,8 +15,8 @@ limitations under the License. // Copied from tensorflow/core/util/ctc/ctc_beam_search.h // TODO(b/111524997): Remove this file. -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ #include #include @@ -25,12 +25,12 @@ limitations under the License. #include #include "third_party/eigen3/Eigen/Core" -#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" -#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h" -#include "tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h" -#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" -#include "tensorflow/contrib/lite/experimental/kernels/top_n.h" -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/experimental/kernels/ctc_beam_entry.h" +#include "tensorflow/lite/experimental/kernels/ctc_beam_scorer.h" +#include "tensorflow/lite/experimental/kernels/ctc_decoder.h" +#include "tensorflow/lite/experimental/kernels/ctc_loss_util.h" +#include "tensorflow/lite/experimental/kernels/top_n.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { namespace experimental { @@ -429,4 +429,4 @@ bool CTCBeamSearchDecoder::TopPaths( } // namespace experimental } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc similarity index 96% rename from tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc rename to tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc index b1ebe4a804a971..9b1a05ee6e77c4 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc +++ b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/kernels/ctc_beam_search.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc rename to tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index 942dbbbeae553b..572b56f1225ccc 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h b/tensorflow/lite/experimental/kernels/ctc_decoder.h similarity index 94% rename from tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h rename to tensorflow/lite/experimental/kernels/ctc_decoder.h index 596ad4a5f7264a..1ceb3f7de47667 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h +++ b/tensorflow/lite/experimental/kernels/ctc_decoder.h @@ -15,8 +15,8 @@ limitations under the License. // Copied from tensorflow/core/util/ctc/ctc_decoder.h // TODO(b/111524997): Remove this file. -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ #include #include @@ -111,4 +111,4 @@ class CTCGreedyDecoder : public CTCDecoder { } // namespace experimental } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h b/tensorflow/lite/experimental/kernels/ctc_loss_util.h similarity index 88% rename from tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h rename to tensorflow/lite/experimental/kernels/ctc_loss_util.h index 0bae732533716a..f2206dbcc07e75 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h +++ b/tensorflow/lite/experimental/kernels/ctc_loss_util.h @@ -15,8 +15,8 @@ limitations under the License. // Copied from tensorflow/core/util/ctc/ctc_loss_util.h // TODO(b/111524997): Remove this file. -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ #include #include @@ -47,4 +47,4 @@ inline float LogSumExp(float log_prob_1, float log_prob_2) { } // namespace experimental } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/top_n.h b/tensorflow/lite/experimental/kernels/top_n.h similarity index 98% rename from tensorflow/contrib/lite/experimental/kernels/top_n.h rename to tensorflow/lite/experimental/kernels/top_n.h index cd2a2f1c80276d..4e2581cc71785c 100644 --- a/tensorflow/contrib/lite/experimental/kernels/top_n.h +++ b/tensorflow/lite/experimental/kernels/top_n.h @@ -38,8 +38,8 @@ limitations under the License. // Copied from tensorflow/core/lib/gtl/top_n.h // TODO(b/111524997): Remove this file. -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ #include #include @@ -47,7 +47,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { namespace gtl { @@ -338,4 +338,4 @@ void TopN::Reset() { } // namespace gtl } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/BUILD b/tensorflow/lite/experimental/micro/BUILD similarity index 71% rename from tensorflow/contrib/lite/experimental/micro/BUILD rename to tensorflow/lite/experimental/micro/BUILD index df1036bc8b9cc8..e11159868e11a0 100644 --- a/tensorflow/contrib/lite/experimental/micro/BUILD +++ b/tensorflow/lite/experimental/micro/BUILD @@ -5,7 +5,7 @@ package( licenses(["notice"]) # Apache 2.0 load( - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "//tensorflow/lite/experimental/micro/testing:micro_test.bzl", "tflite_micro_cc_test", ) @@ -25,10 +25,10 @@ cc_library( "simple_tensor_allocator.h", ], deps = [ - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/core/api", - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core/api", + "//tensorflow/lite/schema:schema_fbs", ], ) @@ -49,7 +49,7 @@ tflite_micro_cc_test( ], deps = [ ":micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) @@ -60,7 +60,7 @@ tflite_micro_cc_test( ], deps = [ ":micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) @@ -71,6 +71,6 @@ tflite_micro_cc_test( ], deps = [ ":micro_framework", - "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + "//tensorflow/lite/experimental/micro/testing:micro_test", ], ) diff --git a/tensorflow/contrib/lite/experimental/micro/README.md b/tensorflow/lite/experimental/micro/README.md similarity index 81% rename from tensorflow/contrib/lite/experimental/micro/README.md rename to tensorflow/lite/experimental/micro/README.md index e03703f4967ba9..673daed74c41a1 100644 --- a/tensorflow/contrib/lite/experimental/micro/README.md +++ b/tensorflow/lite/experimental/micro/README.md @@ -20,7 +20,7 @@ To meet those goals, we've made some tradeoffs: - **Interpreted**: Code generation is a popular pattern for embedded code, because it gives standalone code that's easy to modify and step through, but we've chosen to go with an interpreted approach. In our internal microcontroller work we've found that using an extremely stripped-down interpreter with almost no dependencies gives us a lot of the same advantages, but is easier to maintain. For example, when new updates come out for the underlying library, you can just merge your local modifications in a single step, rather than having to regenerate new code and then patch in any changes you subsequently made. The coarse granularity of the interpreted primitives means that each operation call typically takes hundreds of thousands of instruction cycles at least, so we don't see noticeable performance gains from avoiding what's essentially a single switch statement at the interpreter level to call each operation. We're still working on improving the packaging though, for example we're considering having the ability to snapshot all the source files and headers used for a particular model, being able to compile the code and data together as a library, and then access it through a minimal set of C interface calls which hide the underlying complexity. -- **Flatbuffers**: We represent our models using [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs), with the difference that we always keep it in read-only program memory (typically flash) rather than relying on having a file system to read it from. This is a good fit because flatbuffer's serialized format is designed to be mapped into memory without requiring any extra memory allocations or modifications to access it. All of the functions to read model values work directly on the serialized bytes, and large sections of data like weights are directly accessible as sequential C-style arrays of their data type, with no strides or unpacking needed. We do get a lot of value from using flatbuffers, but there is a cost in complexity. The flat buffer library code is all inline [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema_generated.h), but it isn't straightforward to inspect their implementations, and the model data structures aren't easy to comprehend from the debugger. The header for the schema itself also has to be periodically updated when new information is added to the file format, though we try to handle that transparently for most developers by checking in a pre-generated version. +- **Flatbuffers**: We represent our models using [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs), with the difference that we always keep it in read-only program memory (typically flash) rather than relying on having a file system to read it from. This is a good fit because flatbuffer's serialized format is designed to be mapped into memory without requiring any extra memory allocations or modifications to access it. All of the functions to read model values work directly on the serialized bytes, and large sections of data like weights are directly accessible as sequential C-style arrays of their data type, with no strides or unpacking needed. We do get a lot of value from using flatbuffers, but there is a cost in complexity. The flat buffer library code is all inline [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema_generated.h), but it isn't straightforward to inspect their implementations, and the model data structures aren't easy to comprehend from the debugger. The header for the schema itself also has to be periodically updated when new information is added to the file format, though we try to handle that transparently for most developers by checking in a pre-generated version. - **Code Duplication**: Some of the code in this prototype largely duplicates the logic in other parts of the TensorFlow Lite code base, for example the operator wrappers. We've tried to keep share as much as we can between the two interpreters, but there are some assumptions built into the original runtime that make this difficult. We'll be working on modularizing the main interpreter so that we can move to an entirely shared system. @@ -33,8 +33,8 @@ Building requires a Linux or OS X machine. - Open a terminal - Download the TensorFlow source with `git clone https://github.com/tensorflow` - Enter the source root directory by running `cd tensorflow` - - Download the dependencies by running `tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh`. This may take a few minutes - - Build and test the library with `make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test` + - Download the dependencies by running `tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh`. This may take a few minutes + - Build and test the library with `make -f tensorflow/lite/experimental/micro/tools/make/Makefile test` You should see a series of compilation steps, followed by `~~~ALL TESTS PASSED~~~` for the various tests of the code that it will run. If there's an @@ -43,7 +43,7 @@ error, you should get an informative message from make about what went wrong. These tests are all built as simple binaries with few dependencies, so you can run them manually. For example, here's how to run the depthwise convolution test, and its output: ``` -tensorflow/contrib/lite/experimental/micro/tools/make/gen/linux_x86_64/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test +tensorflow/lite/experimental/micro/tools/make/gen/linux_x86_64/bin/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test Testing SimpleTest Testing SimpleTestQuantized @@ -53,7 +53,7 @@ Testing SimpleTestReluQuantized ~ALL TESTS PASSED~~~ ``` -Looking at the [depthwise_conv_test.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc) code, you'll see a sequence that looks like this: +Looking at the [depthwise_conv_test.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc) code, you'll see a sequence that looks like this: ``` ... @@ -78,19 +78,19 @@ So, why are we running tests in this complicated way? So far, we've been buildin ## Building for the "Blue Pill" STM32F103 -The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the ["Blue Pill" STM32F103-compatible development board](https://github.com/google/googletest) as a cheap and popular platform. It only has 20KB of RAM and 64KB of flash, so it's a good device to ensure we can run efficiently on small chips. +The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the ["Blue Pill" STM32F103-compatible development board](https://github.com/google/stm32_bare_lib) as a cheap and popular platform. It only has 20KB of RAM and 64KB of flash, so it's a good device to ensure we can run efficiently on small chips. It's fairly easy to [buy and wire up a physical board](https://github.com/google/stm32_bare_lib#wiring-up-your-blue-pill), but even if you don't have an actual device, the [Renode project](https://renode.io/) makes it easy to run a faithful emulation on your desktop machine. You'll need [Docker](https://www.docker.com/) installed, but once you have that set up, try running the following command: -`make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test` +`make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test` You should see a similar set of outputs as you did in the previous section, with the addition of some extra Docker logging messages. These are because we're using Docker to run the Renode micro controller emulation tool, and the tests themselves are being run on a simulated STM32F103 device. The communication channels between an embedded device and the host are quite limited, so the test harness looks at the output of the debug log to see if tests have passed, just as it did in the previous section. This makes it a very flexible way to run cross-platform tests, even when a platform has no operating system facilities, as long as it can output debugging text logs. To understand what's happening here, try running the same depthwise convolution test, but through the emulated device test harness, with the following command: ``` -tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh \ -tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test \ +tensorflow/lite/experimental/micro/testing/test_bluepill_binary.sh \ +tensorflow/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test \ '~~~ALL TESTS PASSED~~~' ``` @@ -117,7 +117,7 @@ LOGS: 03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+41µs host +0s virt 0s virt from start] ~~~ALL TESTS PASSED~~~ 03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+5µs host +0s virt 0s virt from start] ... -tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test: PASS +tensorflow/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test: PASS ``` There's a lot of output here, but you should be able to see that the same tests diff --git a/tensorflow/contrib/lite/experimental/micro/compatibility.h b/tensorflow/lite/experimental/micro/compatibility.h similarity index 86% rename from tensorflow/contrib/lite/experimental/micro/compatibility.h rename to tensorflow/lite/experimental/micro/compatibility.h index 4f0fd9f3120a5d..3fa91644bdd64e 100644 --- a/tensorflow/contrib/lite/experimental/micro/compatibility.h +++ b/tensorflow/lite/experimental/micro/compatibility.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ // C++ will automatically create class-specific delete operators for virtual // objects, which by default call the global delete function. For embedded @@ -29,4 +29,4 @@ limitations under the License. #define TF_LITE_REMOVE_VIRTUAL_DELETE #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD new file mode 100644 index 00000000000000..69022b611ed14a --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD @@ -0,0 +1,54 @@ +# Description: +# TensorFlow Lite microcontroller example. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +tflite_micro_cc_test( + name = "micro_speech_test", + srcs = [ + "micro_speech_test.cc", + "no_features_data.cc", + "no_features_data.h", + "tiny_conv_model_data.cc", + "tiny_conv_model_data.h", + "yes_features_data.cc", + "yes_features_data.h", + ], + deps = [ + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/kernels:all_ops_resolver", + "//tensorflow/lite/experimental/micro/kernels:micro_ops", + "//tensorflow/lite/experimental/micro/testing:micro_test", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +tflite_micro_cc_test( + name = "preprocessor_test", + srcs = [ + "no_30ms_sample_data.cc", + "no_30ms_sample_data.h", + "no_power_spectrum_data.cc", + "no_power_spectrum_data.h", + "preprocessor.cc", + "preprocessor.h", + "preprocessor_test.cc", + "yes_30ms_sample_data.cc", + "yes_30ms_sample_data.h", + "yes_power_spectrum_data.cc", + "yes_power_spectrum_data.h", + ], + deps = [ + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/README.md b/tensorflow/lite/experimental/micro/examples/micro_speech/README.md similarity index 84% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/README.md rename to tensorflow/lite/experimental/micro/examples/micro_speech/README.md index 438a432356be5c..500eed33bab018 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/README.md +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/README.md @@ -4,43 +4,43 @@ This examples shows how you can use TensorFlow Lite to run a 20 kilobyte neural ## Table of Contents - * [Getting Started](#getting-started) - * [Getting Started on a Microcontroller](#getting-started-on-a-microcontroller) - * [Calculating the Input to the Neural Network](#calculating-the-input-to-the-neural-network) - * [Creating Your Own Model](#creating-your-own-model) + - [Getting Started](#getting-started) + - [Getting Started on a Microcontroller](#getting-started-on-a-microcontroller) + - [Calculating the Input to the Neural Network](#calculating-the-input-to-the-neural-network) + - [Creating Your Own Model](#creating-your-own-model) ## Getting Started To compile and test this example on a desktop Linux or MacOS machine, download [the TensorFlow source code](https://github.com/tensorflow/tensorflow), `cd` into the source directory from a terminal, and then retrieve the support libraries you need by running: ``` -tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh +tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh ``` This will take a few minutes, and downloads frameworks the code uses like [CMSIS](https://developer.arm.com/embedded/cmsis) and [flatbuffers](https://google.github.io/flatbuffers/). Once that process has finished, run: ``` -make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test_micro_speech +make -f tensorflow/lite/experimental/micro/tools/make/Makefile test_micro_speech ``` You should see a series of files get compiled, followed by some logging output from a test, which should conclude with "~~~ALL TESTS PASSED~~~". If you see this, it means that a small program has been built and run that loads a trained TensorFlow model, runs some example inputs through it, and got the expected outputs. This particular test runs spectrograms generated from recordings of people saying "Yes" and "No", and checks that the network correctly identifies them. -To understand how TensorFlow Lite does this, you can look at the `TestInvoke()` function in [micro_speech_test.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc). It's a fairly small amount of code, creating an interpreter, getting a handle to a model that's been compiled into the program, and then invoking the interpreter with the model and sample inputs. +To understand how TensorFlow Lite does this, you can look at the `TestInvoke()` function in [micro_speech_test.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc). It's a fairly small amount of code, creating an interpreter, getting a handle to a model that's been compiled into the program, and then invoking the interpreter with the model and sample inputs. ## Getting Started on a Microcontroller Once you have downloaded the dependencies and got the x86/Linux build working, you can try building a version for the STM32F103 'bluepill' device. The following command will build the test and then run it on an emulator, assuming you have Docker installed: ``` -make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test_micro_speech +make -f tensorflow/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test_micro_speech ``` If you have a real device [(see here for how to set one up)](https://github.com/google/stm32_bare_lib/tree/master/README.md) you can then convert the ELF file into a a `.bin` format executable to load onto it by running: ``` arm-none-eabi-objcopy \ -tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test \ -tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test.bin \ +tensorflow/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test \ +tensorflow/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/micro_speech_test.bin \ --output binary ``` @@ -89,7 +89,7 @@ bazel run tensorflow/examples/speech_commands:freeze -- \ The next step is to create a TensorFlow Lite file from the frozen graph: ``` -bazel run tensorflow/contrib/lite/toco:toco -- \ +bazel run tensorflow/lite/toco:toco -- \ --input_file=/tmp/tiny_conv.pb --output_file=/tmp/tiny_conv.tflite \ --input_shapes=1,49,43,1 --input_arrays=Reshape_1 --output_arrays='labels_softmax' \ --inference_type=QUANTIZED_UINT8 --mean_values=0 --std_values=2 \ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc similarity index 88% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc index 0f4731fd4b2a08..4e54ff670eb9ba 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h" -#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" -#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" -#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h" +#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/micro_interpreter.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/version.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.cc new file mode 100644 index 00000000000000..6eaa5c2fed61fa --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See the header for documentation on the meaning of this data. + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h" + +const int g_no_30ms_sample_data_size = 480; +const int16_t g_no_30ms_sample_data[480] = { + 5713, 5735, 5735, 5737, 5701, 5691, 5656, 5633, 5611, 5552, 5475, + 5394, 5293, 5177, 5064, 4924, 4737, 4599, 4420, 4237, 4048, 3828, + 3623, 3413, 3183, 2915, 2622, 2308, 1980, 1657, 1261, 901, 549, + 205, -85, -383, -688, -969, -1246, -1530, -1850, -2206, -2561, -2915, + -3224, -3482, -3713, -3921, -4107, -4287, -4470, -4660, -4850, -5057, -5239, + -5395, -5540, -5619, -5697, -5724, -5697, -5675, -5633, -5590, -5579, -5530, + -5486, -5442, -5426, -5391, -5348, -5276, -5197, -5124, -5039, -4925, -4808, + -4677, -4581, -4479, -4343, -4218, -4087, -3970, -3858, -3729, -3570, -3384, + -3206, -3020, -2839, -2636, -2453, -2287, -2185, -2154, -1926, -1562, -1223, + -758, -473, -64, 395, 599, 880, 814, 938, 1172, 1498, 1928, + 2127, 2422, 2608, 2841, 2937, 2886, 2815, 2985, 3324, 3757, 4152, + 4481, 4652, 4917, 4965, 4766, 4583, 4328, 4503, 4815, 5118, 5408, + 5682, 5956, 6082, 6055, 5744, 5426, 5341, 5427, 5606, 5882, 6065, + 6226, 6428, 6477, 6385, 6009, 5728, 5552, 5439, 5339, 5200, 5008, + 4947, 4835, 4614, 4330, 3887, 3521, 3111, 2460, 1983, 1297, 650, + 279, -353, -720, -1044, -1518, -1668, -2117, -2496, -2743, -3266, -3607, + -3790, -4149, -4075, -4042, -4096, -3981, -4138, -4226, -4214, -4503, -4455, + -4577, -4642, -4346, -4351, -4270, -4263, -4522, -4521, -4673, -4814, -4731, + -4950, -5011, -5004, -5288, -5341, -5566, -5833, -5783, -5929, -5847, -5765, + -5828, -5644, -5613, -5615, -5428, -5291, -5014, -4554, -4277, -3964, -3854, + -3829, -3612, -3603, -3438, -3137, -2831, -2164, -1438, -939, -330, -156, + 46, 242, 73, 242, 220, 239, 542, 565, 739, 872, 801, + 857, 676, 543, 586, 567, 828, 1142, 1490, 1985, 2508, 2982, + 3438, 3699, 3939, 4069, 4178, 4420, 4622, 4917, 5338, 5801, 6285, + 6658, 6963, 7213, 7233, 7328, 7176, 7038, 7031, 6860, 6957, 6767, + 6599, 6523, 6212, 6147, 6063, 5860, 6020, 6015, 6033, 6184, 5722, + 5607, 5016, 4337, 4063, 3229, 3080, 3006, 2804, 3035, 2541, 2136, + 1879, 1012, 401, -575, -1584, -1930, -2278, -2485, -2477, -2712, -2747, + -2766, -3320, -3592, -4188, -4669, -4672, -4939, -4789, -4426, -4203, -3674, + -3563, -3656, -3759, -4067, -4257, -4522, -4970, -5204, -5237, -5139, -4907, + -4911, -4917, -4921, -5007, -5230, -5654, -6122, -6464, -6733, -6948, -7067, + -6972, -6800, -6520, -6132, -5830, -5382, -5091, -4797, -4546, -4472, -4362, + -4350, -4235, -3851, -3454, -3144, -2735, -2341, -1845, -1262, -958, -549, + -166, 66, 382, 366, 352, 341, 85, -13, -176, -303, -235, + -341, -309, -227, -249, -50, 143, 384, 874, 1149, 1552, 2155, + 2767, 3499, 3994, 4460, 4920, 5288, 5569, 5704, 5881, 6094, 6461, + 6653, 6803, 7115, 7311, 7521, 7612, 7443, 7380, 7124, 6742, 6495, + 5964, 5656, 5415, 5167, 5656, 5813, 6027, 6401, 6351, 6787, 7019, + 6581, 6512, 5965, 5308, 5140, 4336, 4147, 3899, 3398, 3360, 2830, + 2624, 1968, 1026, 395, -699, -1424, -2327, -3006, -3192, -3435, -3337, + -3686, -3513, -3350, -3502, -3261, -3878, -4005, -4063, -4187, -3767, -3598, + -3384, -3300, -3094, -2857, -3023, -3274, -3851, -4352, -4523, -4943, -5477, + -5612, -5682, -5733, -5714, -5965, -6110, -5950, -6158, -6548, -6897, -7165, + -7281, -7352, -7258, -7185, -6659, -5946, -5470, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h new file mode 100644 index 00000000000000..ff6b874089903d --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was created from the PCM data in a WAV file held in v2 of the +// Speech Commands test dataset, at the path: +// speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav +// The data was extracted starting at an offset of 8,960, which corresponds to +// the 29th spectrogram slice. It's designed to be used to test the +// preprocessing pipeline, to ensure that the expected spectrogram slice is +// produced given this input. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_30MS_SAMPLE_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_30MS_SAMPLE_DATA_H_ + +#include + +extern const int g_no_30ms_sample_data_size; +extern const int16_t g_no_30ms_sample_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_30MS_SAMPLE_DATA_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc similarity index 99% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc index 3615deb26c4f0e..e98c84f7ed2e67 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h" /* File automatically created by * tensorflow/examples/speech_commands/wav_to_features.py \ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h similarity index 74% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h index b53d0a202b75ea..e2ee0c46cf13b0 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ extern const int g_no_f9643d42_nohash_4_width; extern const int g_no_f9643d42_nohash_4_height; extern const unsigned char g_no_f9643d42_nohash_4_data[]; -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_FEATURES_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc new file mode 100644 index 00000000000000..c4fc5c33bb329c --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See the header for documentation on the meaning of this data. + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h" + +const uint8_t g_no_power_spectrum_data[g_no_power_spectrum_data_size] = { + 255, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h new file mode 100644 index 00000000000000..fa39d3c70d78ce --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was extracted from the larger feature data held in +// no_features_data.cc and consists of the 29th spectrogram slice of 43 values. +// This is the expected result of running the sample data in +// no_30ms_sample_data.cc through through the preprocessing pipeline. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_POWER_SPECTRUM_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_POWER_SPECTRUM_DATA_H_ + +#include + +constexpr int g_no_power_spectrum_data_size = 43; +extern const uint8_t g_no_power_spectrum_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_NO_POWER_SPECTRUM_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc new file mode 100644 index 00000000000000..12f9e22038bafa --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Reference implementation of the preprocessing pipeline, with the same +// results as the audio tutorial at +// https://www.tensorflow.org/tutorials/sequences/audio_recognition +// This module takes 30ms of PCM-encoded signed 16-bit audio samples (at 16KHz, +// so 480 values), and extracts a power spectrum of frequencies. There are 43 +// frequency bands in the result, derived from the original 256 output from the +// discrete Fourier transform, and averaged together in groups of 6. +// It's expected that most platforms will have optimized versions of the +// functions used here, for example replacing the DFT with an FFT, so this +// version shouldn't be used where performance is critical. + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" + +#include + +namespace { + +// These constants allow us to allocate fixed-sized arrays on the stack for our +// working memory. +constexpr int kInputSize = 512; +constexpr int kAverageWindowSize = 6; +constexpr int kOutputSize = + ((kInputSize / 2) + (kAverageWindowSize - 1)) / kAverageWindowSize; + +// Performs a discrete Fourier transform on the real inputs. This corresponds to +// rdft() in the FFT package at http://www.kurims.kyoto-u.ac.jp/~ooura/fft.html, +// and to kiss_fftr() in KISSFFT at https://github.com/mborgerding/kissfft. +// It takes in an array of float real values, and returns a result of the same +// length with float real and imaginary components interleaved, so +// fourier_output[0] is the first real value, fourier_output[1] is the first +// imaginary, fourier_output[2] is the second real, and so on. +// The calling function should ensure that the array passed in as fourier_output +// is at least time_series_size in length. Most optimized FFT implementations +// require the length to be a power of two as well, but this version doesn't +// enforce that. +void CalculateDiscreteFourierTransform(float* time_series, int time_series_size, + float* fourier_output) { + for (int i = 0; i < time_series_size / 2; ++i) { + float real = 0; + for (int j = 0; j < time_series_size; ++j) { + real += time_series[j] * cos(j * i * M_PI * 2 / time_series_size); + } + float imaginary = 0; + for (int j = 0; j < time_series_size; ++j) { + imaginary -= time_series[j] * sin(j * i * M_PI * 2 / time_series_size); + } + fourier_output[(i * 2) + 0] = real; + fourier_output[(i * 2) + 1] = imaginary; + } +} + +// Produces a simple sine curve that is used to ensure frequencies at the center +// of the current sample window are weighted more heavily than those at the end. +void CalculatePeriodicHann(int window_length, float* window_function) { + for (int i = 0; i < window_length; ++i) { + window_function[i] = 0.5 - 0.5 * cos((2 * M_PI * i) / window_length); + } +} + +} // namespace + +TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, int output_size, + uint8_t* output) { + // Ensure our input and output data arrays are valid. + if (input_size > kInputSize) { + error_reporter->Report("Input size %d larger than %d", input_size, + kInputSize); + return kTfLiteError; + } + if (output_size != kOutputSize) { + error_reporter->Report("Requested output size %d doesn't match %d", + output_size, kOutputSize); + return kTfLiteError; + } + + // Pre-calculate the window function we'll be applying to the input data. + // In a real application, we'd calculate this table once in an initialization + // function and store it for repeated reuse. + float window_function[kInputSize]; + CalculatePeriodicHann(input_size, window_function); + + // Apply the window function to our time series input, and pad it with zeroes + // to the next power of two. + float float_input[kInputSize]; + for (int i = 0; i < kInputSize; ++i) { + if (i < input_size) { + float_input[i] = + (input[i] * window_function[i]) / static_cast(1 << 15); + } else { + float_input[i] = 0.0f; + } + } + + // Pull the frequency data from the time series sample. + float fourier_values[kInputSize]; + CalculateDiscreteFourierTransform(float_input, kInputSize, fourier_values); + + // We have the complex numbers giving us information about each frequency + // band, but all we want to know is how strong each frequency is, so calculate + // the squared magnitude by adding together the squares of each component. + float power_spectrum[kInputSize / 2]; + for (int i = 0; i < (kInputSize / 2); ++i) { + const float real = fourier_values[(i * 2) + 0]; + const float imaginary = fourier_values[(i * 2) + 1]; + power_spectrum[i] = (real * real) + (imaginary * imaginary); + } + + // Finally, reduce the size of the output by averaging together six adjacent + // frequencies into each slot, producing an array of 43 values. + for (int i = 0; i < kOutputSize; ++i) { + float total = 0.0f; + for (int j = 0; j < kAverageWindowSize; ++j) { + const int index = (i * kAverageWindowSize) + j; + if (index < (kInputSize / 2)) { + total += power_spectrum[index]; + } + } + const float average = total / kAverageWindowSize; + // Quantize the result into eight bits, effectively multiplying by two. + // The 127.5 constant here has to match the features_max value defined in + // tensorflow/examples/speech_commands/input_data.py, and this also assumes + // that features_min is zero. It it wasn't, we'd have to subtract it first. + int quantized_average = roundf(average * (255.0f / 127.5f)); + if (quantized_average < 0) { + quantized_average = 0; + } + if (quantized_average > 255) { + quantized_average = 255; + } + output[i] = quantized_average; + } + return kTfLiteOk; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h new file mode 100644 index 00000000000000..dede2a864219c3 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_PREPROCESSOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_PREPROCESSOR_H_ + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" + +TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter, + const int16_t* input, int input_size, int output_size, + uint8_t* output); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_PREPROCESSOR_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc new file mode 100644 index 00000000000000..e8b49f67e3d72f --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestPreprocessor) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + uint8_t yes_calculated_data[g_yes_power_spectrum_data_size]; + TfLiteStatus yes_status = Preprocess( + error_reporter, g_yes_30ms_sample_data, g_yes_30ms_sample_data_size, + g_yes_power_spectrum_data_size, yes_calculated_data); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, yes_status); + + for (int i = 0; i < g_yes_power_spectrum_data_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_yes_power_spectrum_data[i], + yes_calculated_data[i]); + if (g_yes_power_spectrum_data[i] != yes_calculated_data[i]) { + error_reporter->Report("Expected value %d but found %d", + g_yes_power_spectrum_data[i], + yes_calculated_data[i]); + } + } + + uint8_t no_calculated_data[g_yes_power_spectrum_data_size]; + TfLiteStatus no_status = Preprocess( + error_reporter, g_no_30ms_sample_data, g_no_30ms_sample_data_size, + g_no_power_spectrum_data_size, no_calculated_data); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, no_status); + + for (int i = 0; i < g_no_power_spectrum_data_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(g_no_power_spectrum_data[i], no_calculated_data[i]); + if (g_no_power_spectrum_data[i] != no_calculated_data[i]) { + error_reporter->Report("Expected value %d but found %d", + g_no_power_spectrum_data[i], + no_calculated_data[i]); + } + } +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc similarity index 99% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc index f0769a1237d64a..62e4359859a422 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc @@ -17,7 +17,7 @@ limitations under the License. // xxd -i tiny_conv.tflite > tiny_conv_model_data.cc // See the README for a full description of the creation process. -#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" const unsigned char g_tiny_conv_model_data[] = { 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00, diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h similarity index 78% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h index 2953cc852d98fa..a465dbfabf7cbb 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h @@ -18,10 +18,10 @@ limitations under the License. // don't have a file system. It was created using the command: // xxd -i tiny_conv.tflite > tiny_conv_model_data.cc -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ extern const unsigned char g_tiny_conv_model_data[]; extern const int g_tiny_conv_model_data_len; -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.cc new file mode 100644 index 00000000000000..f089ef82f3a7cc --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See the header for documentation on the meaning of this data. + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h" + +const int g_yes_30ms_sample_data_size = 480; +const int16_t g_yes_30ms_sample_data[480] = { + -876, -470, 510, 803, 170, -787, -1568, -1893, -1598, -1027, + -992, -1803, -2610, -2484, -1905, -2113, -3113, -3399, -2267, -1261, + -2007, -3637, -3909, -2340, -893, -1158, -2272, -2486, -1639, -915, + -777, -596, -91, 196, 85, 210, 875, 1373, 1247, 1219, + 1958, 2718, 2328, 1196, 1008, 2350, 3677, 3269, 1503, 366, + 922, 2264, 2810, 1996, 608, -168, 75, 680, 811, 395, + -56, -318, -607, -966, -1108, -925, -613, -368, -369, -919, + -1926, -2460, -1685, -300, 155, -611, -1524, -2204, -3227, -3859, + -2037, 1622, 2382, -2583, -8448, -7544, -84, 4814, 915, -6423, + -7558, -1746, 2515, -59, -4587, -3858, 1260, 3625, 187, -4148, + -3500, 1542, 5467, 4780, 1256, -1127, -403, 2481, 5332, 6346, + 5014, 2536, 1216, 2467, 5039, 6238, 5070, 3381, 3269, 4173, + 3905, 2248, 1586, 3299, 5240, 4362, 1004, -1382, -489, 2113, + 3168, 1620, -742, -1824, -1435, -897, -1058, -1500, -1545, -1398, + -1965, -3266, -4136, -3756, -2609, -1804, -1986, -3087, -4599, -5296, + -4051, -1731, -781, -2228, -4092, -3977, -2325, -1353, -1568, -1490, + -428, 178, -672, -1650, -1058, 749, 2039, 2079, 1540, 897, + 310, 572, 2266, 4265, 4265, 1869, -231, 559, 3332, 4752, + 3229, 768, 101, 1364, 2463, 1984, 819, 411, 723, 675, + -162, -923, -743, -32, 185, -516, -1653, -2359, -2103, -986, + 42, -205, -1702, -2870, -2337, -809, -221, -982, -1544, -946, + -598, -2117, -4291, -4100, -857, 1948, 338, -4799, -7972, -5403, + 173, 2371, -1063, -5533, -5578, -1777, 605, -985, -3249, -2213, + 1184, 2691, 560, -2356, -2288, 1233, 5244, 6441, 4004, 370, + -663, 2555, 7404, 9282, 6573, 2612, 1836, 4662, 7467, 7393, + 5421, 4262, 4741, 5362, 4705, 3163, 2397, 3337, 4887, 4810, + 2254, -749, -1316, 772, 2706, 2016, -573, -2552, -2746, -2012, + -1647, -1978, -2579, -3105, -3473, -3911, -4484, -4891, -4795, -4163, + -3543, -3538, -4275, -5356, -5743, -4637, -2614, -1301, -1825, -3341, + -4011, -2937, -751, 1007, 1245, 235, -639, -61, 1626, 2864, + 2967, 2734, 3013, 3329, 2914, 2312, 2666, 3839, 4308, 3162, + 1453, 768, 1255, 1887, 2006, 1715, 1031, -297, -1660, -1690, + -277, 813, -30, -2137, -3370, -2854, -1553, -593, -413, -1146, + -2567, -3440, -2369, -205, 379, -1258, -2315, -812, 262, -3205, + -8576, -7894, 738, 7492, 1951, -11595, -17098, -6934, 7139, 8065, + -4575, -14199, -8946, 3606, 7504, -547, -8242, -5113, 4406, 8113, + 2134, -5040, -4089, 4157, 10934, 10158, 4167, -565, -192, 4428, + 9765, 12201, 9861, 4512, 1225, 3451, 8483, 10133, 6497, 2574, + 3333, 6806, 6986, 2487, -1214, 623, 5416, 6647, 2204, -3289, + -4556, -1565, 1544, 1525, -1236, -4293, -5695, -5174, -3995, -3403, + -3449, -3750, -4505, -6014, -7296, -6523, -3849, -2096, -3288, -5722, + -6004, -3581, -1497, -1960, -3330, -2800, -434, 964, -111, -1739, + -1136, 1736, 4151, 3736, 1274, -451, 469, 3386, 5833, 5898, + 3646, 1085, 272, 1743, 4061, 5108, 3837, 1490, 246, 967, + 1866, 859, -1069, -974, 1542, 2835, 47, -4285, -5068, -1567, + 1781, 1223, -1997, -4227, -3747, -1720, 41, 245, -1228, -2972, + -2673, 22, 1980, -930, -7721, -11271, -5725, 4974, 8484, -2007, + -16979, -19255, -4670, 11057, 9690, -6417, -17537, -10841, 4262, 9292, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h new file mode 100644 index 00000000000000..daaeb514a806d0 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was created from the PCM data in a WAV file held in v2 of the +// Speech Commands test dataset, at the path: +// speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav +// The data was extracted starting at an offset of 8,000, which corresponds to +// the 26th spectrogram slice. It's designed to be used to test the +// preprocessing pipeline, to ensure that the expected spectrogram slice is +// produced given this input. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_30MS_SAMPLE_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_30MS_SAMPLE_DATA_H_ + +#include + +extern const int g_yes_30ms_sample_data_size; +extern const int16_t g_yes_30ms_sample_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_30MS_SAMPLE_DATA_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc similarity index 99% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc rename to tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc index 3ad29e53c83ddc..2eb737fb8e1204 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h" /* File automatically created by * tensorflow/examples/speech_commands/wav_to_features.py \ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h similarity index 74% rename from tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h rename to tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h index 33ac2308624235..39a3bb914cc198 100644 --- a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ extern const int g_yes_f2e59fea_nohash_1_width; extern const int g_yes_f2e59fea_nohash_1_height; extern const unsigned char g_yes_f2e59fea_nohash_1_data[]; -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_FEATURES_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc new file mode 100644 index 00000000000000..9a34a2045a221e --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See the header for documentation on the meaning of this data. + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h" + +const uint8_t g_yes_power_spectrum_data[g_yes_power_spectrum_data_size] = { + 8, 89, 8, 0, 0, 0, 0, 0, 0, 0, 0, 4, 13, 1, 6, 23, 20, 6, 4, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h new file mode 100644 index 00000000000000..5c8c00ac1116dc --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This data was extracted from the larger feature data held in +// no_features_data.cc and consists of the 26th spectrogram slice of 43 values. +// This is the expected result of running the sample data in +// yes_30ms_sample_data.cc through through the preprocessing pipeline. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_POWER_SPECTRUM_DATA_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_POWER_SPECTRUM_DATA_H_ + +#include + +constexpr int g_yes_power_spectrum_data_size = 43; +extern const uint8_t g_yes_power_spectrum_data[]; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_YES_POWER_SPECTRUM_DATA_H_ diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD new file mode 100644 index 00000000000000..a54fd41760d58f --- /dev/null +++ b/tensorflow/lite/experimental/micro/kernels/BUILD @@ -0,0 +1,107 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +cc_library( + name = "micro_ops", + srcs = [ + "depthwise_conv.cc", + "fully_connected.cc", + "softmax.cc", + ], + hdrs = [ + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/kernels:padding", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "all_ops_resolver", + srcs = [ + "all_ops_resolver.cc", + ], + hdrs = [ + "all_ops_resolver.h", + ], + copts = tflite_copts(), + deps = [ + ":micro_ops", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "test_utils", + srcs = [ + ], + hdrs = [ + "test_utils.h", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core/api", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "depthwise_conv_test", + srcs = [ + "depthwise_conv_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "fully_connected_test", + srcs = [ + "fully_connected_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "softmax_test", + srcs = [ + "softmax_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc similarity index 95% rename from tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc rename to tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc index bd0a37badb8ab1..b733949e45df9c 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc +++ b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h similarity index 70% rename from tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h rename to tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h index f836064a3f6344..b9ba8c882624bf 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h +++ b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h @@ -9,11 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ -#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" -#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" +#include "tensorflow/lite/experimental/micro/compatibility.h" +#include "tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h" namespace tflite { namespace ops { @@ -31,4 +31,4 @@ class AllOpsResolver : public MicroMutableOpResolver { } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc b/tensorflow/lite/experimental/micro/kernels/depthwise_conv.cc similarity index 93% rename from tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc rename to tensorflow/lite/experimental/micro/kernels/depthwise_conv.cc index 4f17263181982a..ce821a94787796 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc +++ b/tensorflow/lite/experimental/micro/kernels/depthwise_conv.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/padding.h" - -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" + +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc rename to tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc index 169899c471dd44..f70437a4b943e6 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" -#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc b/tensorflow/lite/experimental/micro/kernels/fully_connected.cc similarity index 93% rename from tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc rename to tensorflow/lite/experimental/micro/kernels/fully_connected.cc index 1e9e54cafb8c91..a344c4ffbeded9 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/experimental/micro/kernels/fully_connected.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h" -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc b/tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc similarity index 98% rename from tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc rename to tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc index b42bf4c3bca755..300f8aaf78ad38 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" -#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc b/tensorflow/lite/experimental/micro/kernels/softmax.cc similarity index 93% rename from tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc rename to tensorflow/lite/experimental/micro/kernels/softmax.cc index a4019a067c563c..6d2d8b470fcad5 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc +++ b/tensorflow/lite/experimental/micro/kernels/softmax.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h" -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/internal/reference/softmax.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc b/tensorflow/lite/experimental/micro/kernels/softmax_test.cc similarity index 94% rename from tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc rename to tensorflow/lite/experimental/micro/kernels/softmax_test.cc index 694456d8ace518..7253b3be8ce20f 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/softmax_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" -#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h b/tensorflow/lite/experimental/micro/kernels/test_utils.h similarity index 91% rename from tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h rename to tensorflow/lite/experimental/micro/kernels/test_utils.h index 789a48ece8bd68..4207c609812f16 100644 --- a/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h +++ b/tensorflow/lite/experimental/micro/kernels/test_utils.h @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace testing { @@ -167,4 +167,4 @@ inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list data, } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc b/tensorflow/lite/experimental/micro/micro_error_reporter.cc similarity index 96% rename from tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc rename to tensorflow/lite/experimental/micro/micro_error_reporter.cc index de11c2af5276e7..6bfe541f806306 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc +++ b/tensorflow/lite/experimental/micro/micro_error_reporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h b/tensorflow/lite/experimental/micro/micro_error_reporter.h similarity index 82% rename from tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h rename to tensorflow/lite/experimental/micro/micro_error_reporter.h index 21a014bee09a63..0ab853ec2ac915 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h +++ b/tensorflow/lite/experimental/micro/micro_error_reporter.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/experimental/micro/compatibility.h" #ifdef TF_LITE_MCU_DEBUG_LOG // These functions should be supplied by the micro target library @@ -51,4 +51,4 @@ class MicroErrorReporter : public ErrorReporter { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc b/tensorflow/lite/experimental/micro/micro_error_reporter_test.cc similarity index 93% rename from tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc rename to tensorflow/lite/experimental/micro/micro_error_reporter_test.cc index ef3c32050c0e82..ca89de9739fe1d 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc +++ b/tensorflow/lite/experimental/micro/micro_error_reporter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" int main(int argc, char** argv) { tflite::MicroErrorReporter micro_error_reporter; diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc b/tensorflow/lite/experimental/micro/micro_interpreter.cc similarity index 98% rename from tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc rename to tensorflow/lite/experimental/micro/micro_interpreter.cc index 5ece5edc31acc9..e0460c5d3e5cd8 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc +++ b/tensorflow/lite/experimental/micro/micro_interpreter.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" +#include "tensorflow/lite/experimental/micro/micro_interpreter.h" -#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" -#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/experimental/micro/compatibility.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h b/tensorflow/lite/experimental/micro/micro_interpreter.h similarity index 81% rename from tensorflow/contrib/lite/experimental/micro/micro_interpreter.h rename to tensorflow/lite/experimental/micro/micro_interpreter.h index a88514cde84959..6450dcce96204b 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h +++ b/tensorflow/lite/experimental/micro/micro_interpreter.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/core/api/op_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -68,4 +68,4 @@ class MicroInterpreter { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc b/tensorflow/lite/experimental/micro/micro_interpreter_test.cc similarity index 98% rename from tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc rename to tensorflow/lite/experimental/micro/micro_interpreter_test.cc index 251e5f72037717..0c0c71f0792dd3 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/experimental/micro/micro_interpreter_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" +#include "tensorflow/lite/experimental/micro/micro_interpreter.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc b/tensorflow/lite/experimental/micro/micro_mutable_op_resolver.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc rename to tensorflow/lite/experimental/micro/micro_mutable_op_resolver.cc index 40c21c6448c39f..1e8b5c0e573bd9 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc +++ b/tensorflow/lite/experimental/micro/micro_mutable_op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" +#include "tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h" namespace tflite { diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h b/tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h similarity index 79% rename from tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h rename to tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h index f3750a248416cc..f613203909e2d4 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ -#include "tensorflow/contrib/lite/core/api/op_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/experimental/micro/compatibility.h" #ifndef TFLITE_REGISTRATIONS_MAX #define TFLITE_REGISTRATIONS_MAX (128) @@ -43,4 +43,4 @@ class MicroMutableOpResolver : public OpResolver { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc b/tensorflow/lite/experimental/micro/micro_mutable_op_resolver_test.cc similarity index 94% rename from tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc rename to tensorflow/lite/experimental/micro/micro_mutable_op_resolver_test.cc index 5420a33e8778d9..f551830865dd93 100644 --- a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc +++ b/tensorflow/lite/experimental/micro/micro_mutable_op_resolver_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" +#include "tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc b/tensorflow/lite/experimental/micro/simple_tensor_allocator.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc rename to tensorflow/lite/experimental/micro/simple_tensor_allocator.cc index 555e53afeff837..6ce14edea53dec 100644 --- a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc +++ b/tensorflow/lite/experimental/micro/simple_tensor_allocator.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" -#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h b/tensorflow/lite/experimental/micro/simple_tensor_allocator.h similarity index 78% rename from tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h rename to tensorflow/lite/experimental/micro/simple_tensor_allocator.h index 56fb293675808c..3530ecdfe265f8 100644 --- a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h +++ b/tensorflow/lite/experimental/micro/simple_tensor_allocator.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -48,4 +48,4 @@ class SimpleTensorAllocator { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc b/tensorflow/lite/experimental/micro/simple_tensor_allocator_test.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc rename to tensorflow/lite/experimental/micro/simple_tensor_allocator_test.cc index ab19394502281d..b82017c7fe60e9 100644 --- a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc +++ b/tensorflow/lite/experimental/micro/simple_tensor_allocator_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" +#include "tensorflow/lite/experimental/micro/micro_interpreter.h" -#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/experimental/micro/testing/BUILD b/tensorflow/lite/experimental/micro/testing/BUILD similarity index 77% rename from tensorflow/contrib/lite/experimental/micro/testing/BUILD rename to tensorflow/lite/experimental/micro/testing/BUILD index 0d23be5712ad1b..5a31a709ca3f02 100644 --- a/tensorflow/contrib/lite/experimental/micro/testing/BUILD +++ b/tensorflow/lite/experimental/micro/testing/BUILD @@ -12,6 +12,6 @@ cc_library( "micro_test.h", ], deps = [ - "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro:micro_framework", ], ) diff --git a/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill b/tensorflow/lite/experimental/micro/testing/Dockerfile.bluepill similarity index 100% rename from tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill rename to tensorflow/lite/experimental/micro/testing/Dockerfile.bluepill diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc b/tensorflow/lite/experimental/micro/testing/bluepill.resc similarity index 100% rename from tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc rename to tensorflow/lite/experimental/micro/testing/bluepill.resc diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot b/tensorflow/lite/experimental/micro/testing/bluepill.robot similarity index 94% rename from tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot rename to tensorflow/lite/experimental/micro/testing/bluepill.robot index f09c3a0cc0df84..37612168576280 100644 --- a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot +++ b/tensorflow/lite/experimental/micro/testing/bluepill.robot @@ -17,7 +17,7 @@ Should Run Bluepill Test Execute Command $bin = @${BIN} Execute Script ${SCRIPT} - Create Terminal Tester ${UART} timeout=3 + Create Terminal Tester ${UART} timeout=30 Start Emulation Wait For Line On Uart ${EXPECTED} diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl b/tensorflow/lite/experimental/micro/testing/micro_test.bzl similarity index 96% rename from tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl rename to tensorflow/lite/experimental/micro/testing/micro_test.bzl index 916e3eeac394f9..7a7ba15ca5fca8 100644 --- a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl +++ b/tensorflow/lite/experimental/micro/testing/micro_test.bzl @@ -51,7 +51,7 @@ def tflite_micro_cc_test( name = name, size = "medium", srcs = [ - "//tensorflow/contrib/lite/experimental/micro/testing:test_linux_binary.sh", + "//tensorflow/lite/experimental/micro/testing:test_linux_binary.sh", ], args = [ native.package_name() + "/" + name + "_binary", diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h b/tensorflow/lite/experimental/micro/testing/micro_test.h similarity index 96% rename from tensorflow/contrib/lite/experimental/micro/testing/micro_test.h rename to tensorflow/lite/experimental/micro/testing/micro_test.h index 3b6554dea6a59f..10bab05faec9fd 100644 --- a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h +++ b/tensorflow/lite/experimental/micro/testing/micro_test.h @@ -51,10 +51,10 @@ limitations under the License. // all on systems that struggle to run more conventional approaches, so use with // caution! -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ -#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" namespace micro_test { extern int tests_passed; @@ -153,4 +153,4 @@ extern tflite::ErrorReporter* reporter; } \ } while (false) -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh b/tensorflow/lite/experimental/micro/testing/test_bluepill_binary.sh similarity index 83% rename from tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh rename to tensorflow/lite/experimental/micro/testing/test_bluepill_binary.sh index a470dc52f8d840..e288c6cf568314 100755 --- a/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh +++ b/tensorflow/lite/experimental/micro/testing/test_bluepill_binary.sh @@ -27,8 +27,8 @@ declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt mkdir -p ${MICRO_LOG_PATH} docker build -t renode_bluepill \ - -f ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill \ - ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/ + -f ${ROOT_DIR}/tensorflow/lite/experimental/micro/testing/Dockerfile.bluepill \ + ${ROOT_DIR}/tensorflow/lite/experimental/micro/testing/ exit_code=0 # running in `if` to avoid setting +e @@ -37,10 +37,10 @@ if ! docker run \ -v ${ROOT_DIR}:/workspace \ -v /tmp:/tmp \ -e BIN=/workspace/$1 \ - -e SCRIPT=/workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc \ + -e SCRIPT=/workspace/tensorflow/lite/experimental/micro/testing/bluepill.resc \ -e EXPECTED="$2" \ -it renode_bluepill \ - /bin/bash -c "/opt/renode/tests/test.sh /workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.robot 2>&1 >${MICRO_LOG_FILENAME}" + /bin/bash -c "/opt/renode/tests/test.sh /workspace/tensorflow/lite/experimental/micro/testing/bluepill.robot 2>&1 >${MICRO_LOG_FILENAME}" then exit_code=1 fi diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh b/tensorflow/lite/experimental/micro/testing/test_linux_binary.sh similarity index 100% rename from tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh rename to tensorflow/lite/experimental/micro/testing/test_linux_binary.sh diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile similarity index 63% rename from tensorflow/contrib/lite/experimental/micro/tools/make/Makefile rename to tensorflow/lite/experimental/micro/tools/make/Makefile index 3f749e53ef1aa9..5492003e5af2f3 100644 --- a/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile +++ b/tensorflow/lite/experimental/micro/tools/make/Makefile @@ -1,4 +1,4 @@ -MAKEFILE_DIR := tensorflow/contrib/lite/experimental/micro/tools/make +MAKEFILE_DIR := tensorflow/lite/experimental/micro/tools/make # Try to figure out the host system HOST_OS := @@ -17,7 +17,7 @@ endif HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) # Override these on the make command line to target a specific architecture. For example: -# make -f tensorflow/contrib/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l +# make -f tensorflow/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l TARGET := $(HOST_OS) TARGET_ARCH := $(HOST_ARCH) @@ -33,7 +33,7 @@ INCLUDES := \ # override local versions in the source tree. INCLUDES += -I/usr/local/include -TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh +TEST_SCRIPT := tensorflow/lite/experimental/micro/testing/test_linux_binary.sh MICROLITE_LIBS := -lm @@ -54,24 +54,33 @@ MICROLITE_LIB_NAME := libtensorflow-microlite.a # Test binary for the microcontroller speech model. MICRO_SPEECH_TEST_SRCS := \ -tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \ -tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \ -tensorflow/contrib/lite/experimental/micro/examples/micro_speech/no_features_data.cc \ -tensorflow/contrib/lite/experimental/micro/examples/micro_speech/yes_features_data.cc +tensorflow/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc + +# Test binary for the microcontroller speech model. +PREPROCESSOR_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_30ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_30ms_sample_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/no_power_spectrum_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/yes_power_spectrum_data.cc MICROLITE_TEST_SRCS := \ -$(wildcard tensorflow/contrib/lite/experimental/micro/*test.cc) \ -$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*test.cc) +$(wildcard tensorflow/lite/experimental/micro/*test.cc) \ +$(wildcard tensorflow/lite/experimental/micro/kernels/*test.cc) MICROLITE_CC_BASE_SRCS := \ -$(wildcard tensorflow/contrib/lite/experimental/micro/*.cc) \ -$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*.cc) \ -tensorflow/contrib/lite/c/c_api_internal.c \ -tensorflow/contrib/lite/core/api/error_reporter.cc \ -tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc \ -tensorflow/contrib/lite/core/api/op_resolver.cc \ -tensorflow/contrib/lite/kernels/kernel_util.cc \ -tensorflow/contrib/lite/kernels/internal/quantization_util.cc +$(wildcard tensorflow/lite/experimental/micro/*.cc) \ +$(wildcard tensorflow/lite/experimental/micro/kernels/*.cc) \ +tensorflow/lite/c/c_api_internal.c \ +tensorflow/lite/core/api/error_reporter.cc \ +tensorflow/lite/core/api/flatbuffer_conversions.cc \ +tensorflow/lite/core/api/op_resolver.cc \ +tensorflow/lite/kernels/kernel_util.cc \ +tensorflow/lite/kernels/internal/quantization_util.cc MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) # These target-specific makefiles should modify or replace options like @@ -82,6 +91,7 @@ include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) ALL_SRCS := \ $(MICRO_SPEECH_TEST_SRCS) \ + $(PREPROCESSOR_TEST_SRCS) \ $(MICROLITE_CC_SRCS) \ $(MICROLITE_TEST_SRCS) @@ -94,6 +104,7 @@ LIBDIR := $(GENDIR)lib/ MICROLITE_LIB_PATH := $(LIBDIR)$(MICROLITE_LIB_NAME) MICRO_SPEECH_TEST_BINARY := $(BINDIR)micro_speech_test +PREPROCESSOR_TEST_BINARY := $(BINDIR)preprocessor_test CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc @@ -102,6 +113,9 @@ AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar MICRO_SPEECH_TEST_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_TEST_SRCS)))) +PREPROCESSOR_TEST_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(PREPROCESSOR_TEST_SRCS)))) + MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS)))) @@ -119,16 +133,16 @@ $(OBJDIR)%.o: %.c $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(MICROLITE_LIB_PATH) $(MICRO_SPEECH_TEST_BINARY) +all: $(MICROLITE_LIB_PATH) $(MICRO_SPEECH_TEST_BINARY) $(PREPROCESSOR_TEST_BINARY) microlite: $(MICROLITE_LIB_PATH) # Hack for generating schema file bypassing flatbuffer parsing -tensorflow/contrib/lite/schema/schema_generated.h: - @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h +tensorflow/lite/schema/schema_generated.h: + @cp -u tensorflow/lite/schema/schema_generated.h.OPENSOURCE tensorflow/lite/schema/schema_generated.h # Gathers together all the objects we've compiled into a single '.a' archive. -$(MICROLITE_LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(MICROLITE_LIB_OBJS) +$(MICROLITE_LIB_PATH): tensorflow/lite/schema/schema_generated.h $(MICROLITE_LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(MICROLITE_LIB_PATH) $(MICROLITE_LIB_OBJS) @@ -144,6 +158,19 @@ micro_speech_test_bin: $(MICRO_SPEECH_TEST_BINARY).bin test_micro_speech: $(MICRO_SPEECH_TEST_BINARY) $(TEST_SCRIPT) $(MICRO_SPEECH_TEST_BINARY) '~~~ALL TESTS PASSED~~~' +$(PREPROCESSOR_TEST_BINARY): $(PREPROCESSOR_TEST_OBJS) $(MICROLITE_LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(PREPROCESSOR_TEST_BINARY) $(PREPROCESSOR_TEST_OBJS) \ + $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS) + +preprocessor_test: $(PREPROCESSOR_TEST_BINARY) +preprocessor_test_bin: $(PREPROCESSOR_TEST_BINARY).bin + +test_preprocessor: $(PREPROCESSOR_TEST_BINARY) + $(TEST_SCRIPT) $(PREPROCESSOR_TEST_BINARY) '~~~ALL TESTS PASSED~~~' + + $(BINDIR)%_test : $(OBJDIR)%_test.o $(MICROLITE_LIB_PATH) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh b/tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh similarity index 95% rename from tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh rename to tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh index 62402efddd86af..6749858bdb9ffe 100755 --- a/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh +++ b/tensorflow/lite/experimental/micro/tools/make/download_dependencies.sh @@ -17,9 +17,9 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../../../../../.." +cd "$SCRIPT_DIR/../../../../../.." -DOWNLOADS_DIR=tensorflow/contrib/lite/experimental/micro/tools/make/downloads +DOWNLOADS_DIR=tensorflow/lite/experimental/micro/tools/make/downloads BZL_FILE_PATH=tensorflow/workspace.bzl # Ensure it is being run from repo root diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb/_main.c b/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb/_main.c similarity index 100% rename from tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb/_main.c rename to tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb/_main.c diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb/apollo3evb.ld b/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb/apollo3evb.ld similarity index 100% rename from tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb/apollo3evb.ld rename to tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb/apollo3evb.ld diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc similarity index 93% rename from tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc rename to tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc index 86c5af69a927ff..f722204feaded5 100644 --- a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc +++ b/tensorflow/lite/experimental/micro/tools/make/targets/apollo3evb_makefile.inc @@ -86,11 +86,11 @@ ifeq ($(TARGET), apollo3evb) $(MAKEFILE_DIR)/targets/apollo3evb/am_util_id.c \ $(MAKEFILE_DIR)/targets/apollo3evb/am_util_stdio.c - TEST_SCRIPT := tensorflow/contrib/lite/experimental/log_test/test_apollo3evb_binary.sh + TEST_SCRIPT := tensorflow/lite/experimental/log_test/test_apollo3evb_binary.sh # These are tests that don't currently work on the blue pill. EXCLUDED_TESTS := \ - tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc \ - tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc + tensorflow/lite/experimental/micro/micro_interpreter_test.cc \ + tensorflow/lite/experimental/micro/simple_tensor_allocator_test.cc MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) # These are microcontroller-specific rules for converting the ELF output diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc similarity index 87% rename from tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc rename to tensorflow/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc index 022a8422dc89c0..5e3105a109b99b 100644 --- a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc @@ -47,11 +47,11 @@ ifeq ($(TARGET), bluepill) MICROLITE_CC_SRCS += \ $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) - TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh + TEST_SCRIPT := tensorflow/lite/experimental/micro/testing/test_bluepill_binary.sh # These are tests that don't currently work on the blue pill. EXCLUDED_TESTS := \ - tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc \ - tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc + tensorflow/lite/experimental/micro/micro_interpreter_test.cc \ + tensorflow/lite/experimental/micro/simple_tensor_allocator_test.cc MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) # These are microcontroller-specific rules for converting the ELF output diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/linux_x86_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/linux_x86_makefile.inc new file mode 100644 index 00000000000000..8ea78e8f3e3db7 --- /dev/null +++ b/tensorflow/lite/experimental/micro/tools/make/targets/linux_x86_makefile.inc @@ -0,0 +1,9 @@ +# Settings for x86 on Linux +ifeq ($(TARGET), linux) + ifeq ($(TARGET_ARCH), x86_64) + PLATFORM_FLAGS = \ + -DTF_LITE_DISABLE_X86_NEON + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) + endif +endif diff --git a/tensorflow/contrib/lite/experimental/microfrontend/BUILD b/tensorflow/lite/experimental/microfrontend/BUILD similarity index 81% rename from tensorflow/contrib/lite/experimental/microfrontend/BUILD rename to tensorflow/lite/experimental/microfrontend/BUILD index 5005d0bbae8a71..2f881e3acabdae 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/BUILD +++ b/tensorflow/lite/experimental/microfrontend/BUILD @@ -20,10 +20,10 @@ cc_library( name = "audio_microfrontend", srcs = ["audio_microfrontend.cc"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/experimental/microfrontend/lib:frontend", - "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/lite:framework", + "//tensorflow/lite/experimental/microfrontend/lib:frontend", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:reference", "@flatbuffers", ], ) @@ -35,8 +35,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":audio_microfrontend", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -48,15 +48,15 @@ tf_custom_op_library( "ops/audio_microfrontend_op.cc", ], deps = [ - "//tensorflow/contrib/lite/experimental/microfrontend/lib:frontend", + "//tensorflow/lite/experimental/microfrontend/lib:frontend", ], ) tf_gen_op_libs( op_lib_names = ["audio_microfrontend_op"], deps = [ - "//tensorflow/contrib/lite/experimental/microfrontend/lib:frontend", "//tensorflow/core:lib", + "//tensorflow/lite/experimental/microfrontend/lib:frontend", ], ) diff --git a/tensorflow/contrib/lite/experimental/microfrontend/audio_microfrontend.cc b/tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/audio_microfrontend.cc rename to tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc index 8819d070b85a33..4367fe74a48444 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/audio_microfrontend.cc +++ b/tensorflow/lite/experimental/microfrontend/audio_microfrontend.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/audio_microfrontend_test.cc b/tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/microfrontend/audio_microfrontend_test.cc rename to tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc index 9da58d960679a3..a9119d01831f68 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/audio_microfrontend_test.cc +++ b/tensorflow/lite/experimental/microfrontend/audio_microfrontend_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/BUILD b/tensorflow/lite/experimental/microfrontend/lib/BUILD similarity index 100% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/BUILD rename to tensorflow/lite/experimental/microfrontend/lib/BUILD diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/README b/tensorflow/lite/experimental/microfrontend/lib/README similarity index 100% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/README rename to tensorflow/lite/experimental/microfrontend/lib/README diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h b/tensorflow/lite/experimental/microfrontend/lib/bits.h similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h rename to tensorflow/lite/experimental/microfrontend/lib/bits.h index f81bc2b023e62a..bf15466a3d6484 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h +++ b/tensorflow/lite/experimental/microfrontend/lib/bits.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_ #ifdef __cplusplus #include @@ -99,4 +99,4 @@ static inline int MostSignificantBit64(uint64_t n) { } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_BITS_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.c b/tensorflow/lite/experimental/microfrontend/lib/fft.c similarity index 96% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft.c rename to tensorflow/lite/experimental/microfrontend/lib/fft.c index 1ecbb30b514294..c1dd62fb7d4254 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.c +++ b/tensorflow/lite/experimental/microfrontend/lib/fft.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h b/tensorflow/lite/experimental/microfrontend/lib/fft.h similarity index 84% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h rename to tensorflow/lite/experimental/microfrontend/lib/fft.h index e7644bf2a70f51..aaffa69debb17d 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h +++ b/tensorflow/lite/experimental/microfrontend/lib/fft.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ #include #include @@ -47,4 +47,4 @@ void FftReset(struct FftState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.c b/tensorflow/lite/experimental/microfrontend/lib/fft_io.c similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.c rename to tensorflow/lite/experimental/microfrontend/lib/fft_io.c index cc1ce209d8501b..b01a8848e9d5cb 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.c +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_io.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft_io.h" void FftWriteMemmapPreamble(FILE* fp, const struct FftState* state) { fprintf(fp, "static int16_t fft_input[%zu];\n", state->fft_size); diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h b/tensorflow/lite/experimental/microfrontend/lib/fft_io.h similarity index 76% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h rename to tensorflow/lite/experimental/microfrontend/lib/fft_io.h index 4d10c3a92af7e8..7a59af68266381 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_io.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" #ifdef __cplusplus extern "C" { @@ -31,4 +31,4 @@ void FftWriteMemmap(FILE* fp, const struct FftState* state, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_test.cc b/tensorflow/lite/experimental/microfrontend/lib/fft_test.cc similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/fft_test.cc index b8684a0b5c0187..7c1ee2d852201c 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.c b/tensorflow/lite/experimental/microfrontend/lib/fft_util.c similarity index 96% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.c rename to tensorflow/lite/experimental/microfrontend/lib/fft_util.c index 55494422f375e1..40cb9f87358087 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_util.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h" #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h b/tensorflow/lite/experimental/microfrontend/lib/fft_util.h similarity index 76% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h rename to tensorflow/lite/experimental/microfrontend/lib/fft_util.h index 4935e87fc1ab8b..6a471301c3f0a5 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/fft_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" #ifdef __cplusplus extern "C" { @@ -31,4 +31,4 @@ void FftFreeStateContents(struct FftState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FFT_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.c b/tensorflow/lite/experimental/microfrontend/lib/filterbank.c similarity index 96% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.c rename to tensorflow/lite/experimental/microfrontend/lib/filterbank.c index 944eb1a7379746..22cfaf78ab4662 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.c +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank.c @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank.h" #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" +#include "tensorflow/lite/experimental/microfrontend/lib/bits.h" void FilterbankConvertFftComplexToEnergy(struct FilterbankState* state, struct complex_int16_t* fft_output, diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h b/tensorflow/lite/experimental/microfrontend/lib/filterbank.h similarity index 86% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h rename to tensorflow/lite/experimental/microfrontend/lib/filterbank.h index 0dd9c3fa651680..1e6d3885f2c227 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ #include #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" #define kFilterbankBits 12 @@ -60,4 +60,4 @@ void FilterbankReset(struct FilterbankState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.c b/tensorflow/lite/experimental/microfrontend/lib/filterbank_io.c similarity index 97% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.c rename to tensorflow/lite/experimental/microfrontend/lib/filterbank_io.c index 672ddd530f847c..2dbb4b3bf09654 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.c +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_io.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank_io.h" static void PrintArray(FILE* fp, const char* name, const int16_t* values, size_t size) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h b/tensorflow/lite/experimental/microfrontend/lib/filterbank_io.h similarity index 75% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h rename to tensorflow/lite/experimental/microfrontend/lib/filterbank_io.h index 1ddc314df2234d..5fc96845897c56 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_io.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank.h" #ifdef __cplusplus extern "C" { @@ -32,4 +32,4 @@ void FilterbankWriteMemmap(FILE* fp, const struct FilterbankState* state, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_test.cc b/tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc index 88d8de4b8f0ebd..808d527186eaa9 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h" #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.c b/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c similarity index 99% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.c rename to tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c index 53b5e45073455b..ce8b4acc0f696f 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h b/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h similarity index 81% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h rename to tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h index 9ec9bc930286e3..781d102479b428 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank.h" #ifdef __cplusplus extern "C" { @@ -47,4 +47,4 @@ void FilterbankFreeStateContents(struct FilterbankState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FILTERBANK_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.c b/tensorflow/lite/experimental/microfrontend/lib/frontend.c similarity index 94% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.c rename to tensorflow/lite/experimental/microfrontend/lib/frontend.c index de7a60b56fd85a..7a618d9af5e797 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.c +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend.c @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" +#include "tensorflow/lite/experimental/microfrontend/lib/bits.h" struct FrontendOutput FrontendProcessSamples(struct FrontendState* state, const int16_t* samples, diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h b/tensorflow/lite/experimental/microfrontend/lib/frontend.h similarity index 73% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h rename to tensorflow/lite/experimental/microfrontend/lib/frontend.h index 71ae81024cb3ae..883df5fd3d05c5 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend.h @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ #include #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window.h" #ifdef __cplusplus extern "C" { @@ -61,4 +61,4 @@ void FrontendReset(struct FrontendState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.c b/tensorflow/lite/experimental/microfrontend/lib/frontend_io.c similarity index 83% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.c rename to tensorflow/lite/experimental/microfrontend/lib/frontend_io.c index 40bcf247497dff..b422d078a6faaf 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.c +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_io.c @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_io.h" #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_io.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_io.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window_io.h" int WriteFrontendStateMemmap(const char* header, const char* source, const struct FrontendState* state) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h b/tensorflow/lite/experimental/microfrontend/lib/frontend_io.h similarity index 73% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h rename to tensorflow/lite/experimental/microfrontend/lib/frontend_io.h index 4f45577caeab7a..0d59eda7d093f0 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_io.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" #ifdef __cplusplus extern "C" { @@ -28,4 +28,4 @@ int WriteFrontendStateMemmap(const char* header, const char* source, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_main.c b/tensorflow/lite/experimental/microfrontend/lib/frontend_main.c similarity index 93% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_main.c rename to tensorflow/lite/experimental/microfrontend/lib/frontend_main.c index 46caebeec9059c..4a8411b6214df3 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_main.c +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_main.c @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" int main(int argc, char** argv) { struct FrontendConfig frontend_config; diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_generator.c b/tensorflow/lite/experimental/microfrontend/lib/frontend_memmap_generator.c similarity index 86% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_generator.c rename to tensorflow/lite/experimental/microfrontend/lib/frontend_memmap_generator.c index a4c59b0cccabb7..766b7f2ad568db 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_generator.c +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_memmap_generator.c @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_io.h" int main(int argc, char** argv) { if (argc != 3) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_main.c b/tensorflow/lite/experimental/microfrontend/lib/frontend_memmap_main.c similarity index 96% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_main.c rename to tensorflow/lite/experimental/microfrontend/lib/frontend_memmap_main.c index a4264922b94b5a..cf39e93a78361d 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_memmap_main.c +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_memmap_main.c @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" #include "memmap.h" int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_test.cc b/tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc similarity index 96% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc index f06e2565c285e6..993e866cc08850 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.c b/tensorflow/lite/experimental/microfrontend/lib/frontend_util.c similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.c rename to tensorflow/lite/experimental/microfrontend/lib/frontend_util.c index ae2d9ae6c4c8bb..94c15adcafe926 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_util.c @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" #include #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" +#include "tensorflow/lite/experimental/microfrontend/lib/bits.h" void FrontendFillConfigWithDefaults(struct FrontendConfig* config) { WindowFillConfigWithDefaults(&config->window); diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h b/tensorflow/lite/experimental/microfrontend/lib/frontend_util.h similarity index 62% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h rename to tensorflow/lite/experimental/microfrontend/lib/frontend_util.h index a958b610eae689..895ce6cd2b2b08 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/frontend_util.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/fft_util.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/filterbank_util.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/filterbank_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window_util.h" #ifdef __cplusplus extern "C" { @@ -49,4 +49,4 @@ void FrontendFreeStateContents(struct FrontendState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_FRONTEND_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.c b/tensorflow/lite/experimental/microfrontend/lib/log_lut.c similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.c rename to tensorflow/lite/experimental/microfrontend/lib/log_lut.c index f8d32102336d19..f59618e028f488 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.c +++ b/tensorflow/lite/experimental/microfrontend/lib/log_lut.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_lut.h" const uint16_t kLogLut[] #ifndef _MSC_VER __attribute__((aligned(4))) diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h b/tensorflow/lite/experimental/microfrontend/lib/log_lut.h similarity index 82% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h rename to tensorflow/lite/experimental/microfrontend/lib/log_lut.h index 53dd1fa4052d2f..b2448a32289a91 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h +++ b/tensorflow/lite/experimental/microfrontend/lib/log_lut.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ #include @@ -37,4 +37,4 @@ extern const uint16_t kLogLut[]; } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_LUT_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.c b/tensorflow/lite/experimental/microfrontend/lib/log_scale.c similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.c rename to tensorflow/lite/experimental/microfrontend/lib/log_scale.c index 4b1246187155e8..54f370e7d9f552 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.c +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale.c @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_lut.h" +#include "tensorflow/lite/experimental/microfrontend/lib/bits.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_lut.h" #define kuint16max 0x0000FFFF diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h b/tensorflow/lite/experimental/microfrontend/lib/log_scale.h similarity index 82% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h rename to tensorflow/lite/experimental/microfrontend/lib/log_scale.h index 8fd60999330492..a383f32f5bc6df 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ #include #include @@ -36,4 +36,4 @@ uint16_t* LogScaleApply(struct LogScaleState* state, uint32_t* signal, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.c b/tensorflow/lite/experimental/microfrontend/lib/log_scale_io.c similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.c rename to tensorflow/lite/experimental/microfrontend/lib/log_scale_io.c index f59cde951ca40a..a04760de58e5fb 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.c +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale_io.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale_io.h" void LogScaleWriteMemmap(FILE* fp, const struct LogScaleState* state, const char* variable) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h b/tensorflow/lite/experimental/microfrontend/lib/log_scale_io.h similarity index 73% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h rename to tensorflow/lite/experimental/microfrontend/lib/log_scale_io.h index 5444303b2445ac..9d447ac9018b12 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_io.h +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale_io.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale.h" #ifdef __cplusplus extern "C" { @@ -30,4 +30,4 @@ void LogScaleWriteMemmap(FILE* fp, const struct LogScaleState* state, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_test.cc b/tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc index 312d7ea7406a14..91ca657e543d2a 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.c b/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.c rename to tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c index 8a025fbf72d9db..0e3dd1d1e94687 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h" void LogScaleFillConfigWithDefaults(struct LogScaleConfig* config) { config->enable_log = 1; diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h b/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h similarity index 79% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h rename to tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h index 33b21f30b10930..11f7d9eeb9b793 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/log_scale_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ #include #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/log_scale.h" +#include "tensorflow/lite/experimental/microfrontend/lib/log_scale.h" #ifdef __cplusplus extern "C" { @@ -42,4 +42,4 @@ int LogScalePopulateState(const struct LogScaleConfig* config, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_LOG_SCALE_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.c b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.c rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c index 92f8b58d74f9d3..b6fcb5e9409c07 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.c +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h" #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h similarity index 83% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h index cc2cf2d9b742f9..46d3f52e60e376 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ #define kNoiseReductionBits 14 @@ -43,4 +43,4 @@ void NoiseReductionReset(struct NoiseReductionState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.c b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.c similarity index 94% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.c rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.c index 1cba410436ad2b..19c32b32759ed6 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.c +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.h" void NoiseReductionWriteMemmapPreamble( FILE* fp, const struct NoiseReductionState* state) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.h similarity index 75% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.h index afeedfce99d09b..ded52118f0cab0 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_io.h +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_io.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h" #ifdef __cplusplus extern "C" { @@ -33,4 +33,4 @@ void NoiseReductionWriteMemmap(FILE* fp, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_test.cc b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc index f4cf486227a260..16140564879305 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.c b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.c rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c index 46f475352e0670..a6c9234eb888da 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h" #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h similarity index 82% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h rename to tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h index 207b8a679dac9d..fa55539143fca6 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/noise_reduction_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/noise_reduction.h" +#include "tensorflow/lite/experimental/microfrontend/lib/noise_reduction.h" #ifdef __cplusplus extern "C" { @@ -47,4 +47,4 @@ void NoiseReductionFreeStateContents(struct NoiseReductionState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_NOISE_REDUCTION_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.c b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c similarity index 92% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.c rename to tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c index 551d552e8f63a4..b49eb301370a7e 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.c +++ b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.c @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/bits.h" +#include "tensorflow/lite/experimental/microfrontend/lib/bits.h" int16_t WideDynamicFunction(const uint32_t x, const int16_t* lut) { if (x <= 2) { diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h similarity index 82% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h rename to tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h index cab74f49dbece6..81557913223361 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h +++ b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ #include #include @@ -43,4 +43,4 @@ void PcanGainControlApply(struct PcanGainControlState* state, uint32_t* signal); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc similarity index 91% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc index bbc36d6eac7757..830db89edd8eb3 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.c b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c similarity index 97% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.c rename to tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c index 4226b390bc1427..dbe44c494ae07f 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h similarity index 84% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h rename to tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h index 79c0b1da693651..d4bfaa2ed71d22 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/pcan_gain_control.h" +#include "tensorflow/lite/experimental/microfrontend/lib/pcan_gain_control.h" #define kWideDynamicFunctionBits 32 #define kWideDynamicFunctionLUTSize (4 * kWideDynamicFunctionBits - 3) @@ -54,4 +54,4 @@ void PcanGainControlFreeStateContents(struct PcanGainControlState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_PCAN_GAIN_CONTROL_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window.c b/tensorflow/lite/experimental/microfrontend/lib/window.c similarity index 97% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window.c rename to tensorflow/lite/experimental/microfrontend/lib/window.c index 0fdc040a7a58c7..517b60487becb9 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window.c +++ b/tensorflow/lite/experimental/microfrontend/lib/window.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window.h" #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window.h b/tensorflow/lite/experimental/microfrontend/lib/window.h similarity index 85% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window.h rename to tensorflow/lite/experimental/microfrontend/lib/window.h index 90291e5c7238b5..bad8151412fe22 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window.h +++ b/tensorflow/lite/experimental/microfrontend/lib/window.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ #include #include @@ -46,4 +46,4 @@ void WindowReset(struct WindowState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.c b/tensorflow/lite/experimental/microfrontend/lib/window_io.c similarity index 95% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.c rename to tensorflow/lite/experimental/microfrontend/lib/window_io.c index f1fee7c1eda9d0..ed4ac5eb110c0f 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.c +++ b/tensorflow/lite/experimental/microfrontend/lib/window_io.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window_io.h" void WindowWriteMemmapPreamble(FILE* fp, const struct WindowState* state) { fprintf(fp, "static int16_t window_coefficients[] = {\n"); diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h b/tensorflow/lite/experimental/microfrontend/lib/window_io.h similarity index 75% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h rename to tensorflow/lite/experimental/microfrontend/lib/window_io.h index 2bab9064c1fa70..a76b2dc3e81238 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_io.h +++ b/tensorflow/lite/experimental/microfrontend/lib/window_io.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ #include -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window.h" #ifdef __cplusplus extern "C" { @@ -31,4 +31,4 @@ void WindowWriteMemmap(FILE* fp, const struct WindowState* state, } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_IO_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_test.cc b/tensorflow/lite/experimental/microfrontend/lib/window_test.cc similarity index 97% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window_test.cc rename to tensorflow/lite/experimental/microfrontend/lib/window_test.cc index a6c0879faa8ac5..8c6c19188d3e12 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_test.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/window_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.c b/tensorflow/lite/experimental/microfrontend/lib/window_util.c similarity index 96% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.c rename to tensorflow/lite/experimental/microfrontend/lib/window_util.c index 3adde0fb0a6855..2445c343be11e8 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.c +++ b/tensorflow/lite/experimental/microfrontend/lib/window_util.c @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window_util.h" #include #include diff --git a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h b/tensorflow/lite/experimental/microfrontend/lib/window_util.h similarity index 80% rename from tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h rename to tensorflow/lite/experimental/microfrontend/lib/window_util.h index 52dc8f38cc8bfc..68e4de9eb586ec 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/lib/window_util.h +++ b/tensorflow/lite/experimental/microfrontend/lib/window_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/window.h" +#include "tensorflow/lite/experimental/microfrontend/lib/window.h" #ifdef __cplusplus extern "C" { @@ -42,4 +42,4 @@ void WindowFreeStateContents(struct WindowState* state); } // extern "C" #endif -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICROFRONTEND_LIB_WINDOW_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/microfrontend/ops/audio_microfrontend_op.cc b/tensorflow/lite/experimental/microfrontend/ops/audio_microfrontend_op.cc similarity index 98% rename from tensorflow/contrib/lite/experimental/microfrontend/ops/audio_microfrontend_op.cc rename to tensorflow/lite/experimental/microfrontend/ops/audio_microfrontend_op.cc index 83b9528ab5e23e..51094a976d297a 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/ops/audio_microfrontend_op.cc +++ b/tensorflow/lite/experimental/microfrontend/ops/audio_microfrontend_op.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend.h" -#include "tensorflow/contrib/lite/experimental/microfrontend/lib/frontend_util.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h" +#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" diff --git a/tensorflow/contrib/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py b/tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py similarity index 98% rename from tensorflow/contrib/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py rename to tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py index e1661d6608777a..020d40bc13c1bb 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py +++ b/tensorflow/lite/experimental/microfrontend/python/kernel_tests/audio_microfrontend_op_test.py @@ -20,7 +20,7 @@ import tensorflow as tf -from tensorflow.contrib.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op +from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op SAMPLE_RATE = 1000 WINDOW_SIZE = 25 diff --git a/tensorflow/contrib/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py b/tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py similarity index 98% rename from tensorflow/contrib/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py rename to tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py index c6e5e34760f20c..3d49482f4ecd34 100644 --- a/tensorflow/contrib/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py +++ b/tensorflow/lite/experimental/microfrontend/python/ops/audio_microfrontend_op.py @@ -20,7 +20,7 @@ import tensorflow as tf -from tensorflow.contrib.lite.experimental.microfrontend.ops import gen_audio_microfrontend_op +from tensorflow.lite.experimental.microfrontend.ops import gen_audio_microfrontend_op from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/lite/experimental/writer/BUILD similarity index 61% rename from tensorflow/contrib/lite/experimental/writer/BUILD rename to tensorflow/lite/experimental/writer/BUILD index 82d39c00abd27d..506c668cf2c70f 100644 --- a/tensorflow/contrib/lite/experimental/writer/BUILD +++ b/tensorflow/lite/experimental/writer/BUILD @@ -8,7 +8,7 @@ cc_binary( name = "option_writer_generator", srcs = ["option_writer_generator.cc"], deps = [ - "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + "//tensorflow/lite/schema:schema_fbs_with_reflection", "@flatbuffers", ], ) @@ -27,11 +27,11 @@ cc_library( ], textual_hdrs = ["option_writer_generated.h"], deps = [ - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/schema:schema_fbs_with_reflection", ], ) @@ -40,8 +40,8 @@ cc_binary( srcs = ["writer.cc"], deps = [ ":writer_lib", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", ], ) @@ -51,9 +51,9 @@ cc_test( srcs = ["writer_lib_test.cc"], deps = [ ":writer_lib", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/experimental/writer/enum_mapping.h similarity index 91% rename from tensorflow/contrib/lite/experimental/writer/enum_mapping.h rename to tensorflow/lite/experimental/writer/enum_mapping.h index 8bc464fd7188a2..cb6ec3e0d7e0f1 100644 --- a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h +++ b/tensorflow/lite/experimental/writer/enum_mapping.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/schema/reflection/schema_generated.h" // TODO(aselle): Ideally extract this from the schema. @@ -113,4 +113,4 @@ inline LSHProjectionType LSHProjectionTypeToSchema( } } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc similarity index 99% rename from tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc rename to tensorflow/lite/experimental/writer/option_writer_generator.cc index dc32817b86038f..036809e94abcfc 100644 --- a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include #include "flatbuffers/minireflect.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/schema/reflection/schema_generated.h" namespace tflite { namespace { // This is generated by grepping -// cat third_party/tensorflow/contrib/lite/builtin_op_data.h +// cat third_party/tensorflow/lite/builtin_op_data.h //| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" static const char* param_structs[] = {"TfLiteConvParams", "TfLitePoolParams", diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/lite/experimental/writer/writer.cc similarity index 89% rename from tensorflow/contrib/lite/experimental/writer/writer.cc rename to tensorflow/lite/experimental/writer/writer.cc index 20ede214fba795..c1de0333676041 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer.cc +++ b/tensorflow/lite/experimental/writer/writer.cc @@ -20,9 +20,9 @@ limitations under the License. #include -#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/experimental/writer/writer_lib.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" int main(int argc, char* argv[]) { if (argc != 3) { diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc similarity index 95% rename from tensorflow/contrib/lite/experimental/writer/writer_lib.cc rename to tensorflow/lite/experimental/writer/writer_lib.cc index 555a9cc4b09f30..a0ce4b716d62c5 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/lite/experimental/writer/writer_lib.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include "tensorflow/lite/experimental/writer/writer_lib.h" #include #include #include -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/version.h" namespace tflite { template @@ -33,7 +33,7 @@ using FlatBufferBuilder = flatbuffers::FlatBufferBuilder; std::pair> CreateBuiltinUnion( FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) { switch (op) { -#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h" +#include "tensorflow/lite/experimental/writer/option_writer_generated.h" } return std::make_pair(BuiltinOptions_NONE, Offset()); } diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/lite/experimental/writer/writer_lib.h similarity index 88% rename from tensorflow/contrib/lite/experimental/writer/writer_lib.h rename to tensorflow/lite/experimental/writer/writer_lib.h index a5f14697cfd223..08c0436932ffc9 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.h +++ b/tensorflow/lite/experimental/writer/writer_lib.h @@ -16,7 +16,7 @@ limitations under the License. // // Usage: // From command line: -// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer +// bazel run third_party/tensorflow/lite/experimental/writer:writer // -- foo.tflite foo.out.tflite // // From C++ @@ -24,16 +24,16 @@ limitations under the License. // // Build Interpreter however // // ... // InterpreterWriter(interpreter.get()).Write("output.tflite"); -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ #include #include -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/version.h" namespace tflite { @@ -128,4 +128,4 @@ class InterpreterWriter { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/lite/experimental/writer/writer_lib_test.cc similarity index 89% rename from tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc rename to tensorflow/lite/experimental/writer/writer_lib_test.cc index 49194a76c8c084..e04c678a50f72a 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc +++ b/tensorflow/lite/experimental/writer/writer_lib_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include "tensorflow/lite/experimental/writer/writer_lib.h" #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { // Make an interpreter that has no tensors and no nodes diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml similarity index 94% rename from tensorflow/contrib/lite/g3doc/_book.yaml rename to tensorflow/lite/g3doc/_book.yaml index 1d916d0583c096..ab0d186848fcac 100644 --- a/tensorflow/contrib/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -1,7 +1,7 @@ upper_tabs: # Tabs left of dropdown menu - include: /_upper_tabs_left.yaml -- include: /versions/_upper_tabs_versions.yaml +- include: /api_docs/_upper_tabs_api.yaml # Dropdown menu - name: Ecosystem path: /ecosystem @@ -57,7 +57,7 @@ upper_tabs: - title: Post-training quantization path: /lite/performance/post_training_quantization - title: Post-training quantization example - path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb status: external - title: TF Mobile @@ -81,4 +81,4 @@ upper_tabs: skip_translation: true contents: - title: API - path: /api_docs/python/tf/contrib/lite + path: /api_docs/python/tf/lite diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/lite/g3doc/_index.yaml similarity index 99% rename from tensorflow/contrib/lite/g3doc/_index.yaml rename to tensorflow/lite/g3doc/_index.yaml index 44ee6ba7505d42..43b5e3cfc01b45 100644 --- a/tensorflow/contrib/lite/g3doc/_index.yaml +++ b/tensorflow/lite/g3doc/_index.yaml @@ -211,8 +211,8 @@ landing_page: path: https://www.youtube.com/watch?v=FAMfy7izB6A - heading: TensorFlow Lite on GitHub image_path: /ecosystem/images/github-card-16x9.png - path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite buttons: - label: View on GitHub - path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite - classname: devsite-landing-row-item-hidden diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/lite/g3doc/_project.yaml similarity index 100% rename from tensorflow/contrib/lite/g3doc/_project.yaml rename to tensorflow/lite/g3doc/_project.yaml diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/lite/g3doc/apis.md similarity index 99% rename from tensorflow/contrib/lite/g3doc/apis.md rename to tensorflow/lite/g3doc/apis.md index 69616c7b8a3c1c..e9fa24bff1d1a3 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/lite/g3doc/apis.md @@ -347,7 +347,7 @@ interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); where each entry in `inputs` corresponds to an input tensor and `map_of_indices_to_outputs` maps indices of output tensors to the corresponding output data. In both cases the tensor indices should correspond to -the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md) +the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco/g3doc/cmdline_examples.md) when the model was created. Be aware that the order of tensors in `input` must match the order given to the `TensorFlow Lite Optimized Converter`. diff --git a/tensorflow/contrib/lite/g3doc/convert/cmdline_examples.md b/tensorflow/lite/g3doc/convert/cmdline_examples.md similarity index 97% rename from tensorflow/contrib/lite/g3doc/convert/cmdline_examples.md rename to tensorflow/lite/g3doc/convert/cmdline_examples.md index 44fb4f19aeb12f..59f26b35051ce2 100644 --- a/tensorflow/contrib/lite/g3doc/convert/cmdline_examples.md +++ b/tensorflow/lite/g3doc/convert/cmdline_examples.md @@ -18,7 +18,7 @@ There are two approaches to running the converter in the command line. [clone the TensorFlow repository](https://www.tensorflow.org/install/source) and use `bazel`. * Example: `bazel run - //tensorflow/contrib/lite/python:tflite_convert -- + //tensorflow/lite/python:tflite_convert -- --output_file=...` ### Converting models prior to TensorFlow 1.9 @@ -95,11 +95,10 @@ tflite_convert \ The TensorFlow Lite Converter is compatible with fixed point quantization models described [here](https://www.tensorflow.org/performance/quantization). These are -float models with -[`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization) -ops inserted at the boundaries of fused layers to record min-max range -information. This generates a quantized inference workload that reproduces the -quantization behavior that was used during training. +float models with `FakeQuant*` ops inserted at the boundaries of fused layers +to record min-max range information. This generates a quantized inference +workload that reproduces the quantization behavior that was used during +training. The following command generates a quantized TensorFlow Lite FlatBuffer from a "quantized" TensorFlow GraphDef. diff --git a/tensorflow/contrib/lite/g3doc/convert/cmdline_reference.md b/tensorflow/lite/g3doc/convert/cmdline_reference.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/convert/cmdline_reference.md rename to tensorflow/lite/g3doc/convert/cmdline_reference.md diff --git a/tensorflow/contrib/lite/g3doc/convert/index.md b/tensorflow/lite/g3doc/convert/index.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/convert/index.md rename to tensorflow/lite/g3doc/convert/index.md diff --git a/tensorflow/contrib/lite/g3doc/convert/python_api.md b/tensorflow/lite/g3doc/convert/python_api.md similarity index 85% rename from tensorflow/contrib/lite/g3doc/convert/python_api.md rename to tensorflow/lite/g3doc/convert/python_api.md index 9dcb79187ec9bd..4bdf0d8cbe8f57 100644 --- a/tensorflow/contrib/lite/g3doc/convert/python_api.md +++ b/tensorflow/lite/g3doc/convert/python_api.md @@ -3,6 +3,11 @@ This page provides examples on how to use the TensorFlow Lite Converter and the TensorFlow Lite interpreter using the Python API. +Note: TFLite recently moved from `tf.contrib.lite` to `tf.lite`. If you are +using tensorflow `r1.12` or earlier you will need to add `.contrib` to the +commands below. `tf.lite` works with newer builds, like the nightly build, +which can be installed with: `pip install tf-nightly` + [TOC] @@ -16,8 +21,8 @@ be targeted to devices with mobile. ## API The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9 -is `tf.contrib.lite.TFLiteConverter`. The API for calling the Python intepreter -is `tf.contrib.lite.Interpreter`. +is `tf.lite.TFLiteConverter`. The API for calling the Python intepreter +is `tf.lite.Interpreter`. Note: Reference "Additional Instructions" sections for converting TensorFlow models to TensorFlow Lite @@ -52,7 +57,7 @@ out = tf.identity(val, name="out") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) + converter = tf.lite.TFLiteConverter.from_session(sess, [img], [out]) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -75,7 +80,7 @@ graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" input_arrays = ["input"] output_arrays = ["MobilenetV1/Predictions/Softmax"] -converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph( +converter = tf.lite.TFLiteConverter.from_frozen_graph( graph_def_file, input_arrays, output_arrays) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) @@ -89,7 +94,7 @@ FlatBuffer. ```python import tensorflow as tf -converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir) +converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -97,7 +102,7 @@ open("converted_model.tflite", "wb").write(tflite_model) For more complex SavedModels, the optional parameters that can be passed into `TFLiteConverter.from_saved_model()` are `input_arrays`, `input_shapes`, `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are -available by running `help(tf.contrib.lite.TFLiteConverter)`. +available by running `help(tf.lite.TFLiteConverter)`. ### Exporting a tf.keras File @@ -108,7 +113,7 @@ Lite FlatBuffer. This example requires ```python import tensorflow as tf -converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file("keras_model.h5") +converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5") tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -140,7 +145,7 @@ keras_file = "keras_model.h5" tf.keras.models.save_model(model, keras_file) # Convert to TensorFlow Lite model. -converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file) +converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -149,8 +154,8 @@ open("converted_model.tflite", "wb").write(tflite_model) For models where the default value of the attributes is not sufficient, the attribute's values should be set before calling `convert()`. In order to call -any constants use `tf.contrib.lite.constants.` as seen below with -`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TFLiteConverter)` in the Python +any constants use `tf.lite.constants.` as seen below with +`QUANTIZED_UINT8`. Run `help(tf.lite.TFLiteConverter)` in the Python terminal for detailed documentation on the attributes. Although the examples are demonstrated on GraphDefs containing only constants. @@ -170,8 +175,8 @@ val = img + const out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") with tf.Session() as sess: - converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) - converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 + converter = tf.lite.TFLiteConverter.from_session(sess, [img], [out]) + converter.inference_type = tf.lite.constants.QUANTIZED_UINT8 input_arrays = converter.get_input_arrays() converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev tflite_model = converter.convert() @@ -185,7 +190,7 @@ with tf.Session() as sess: The following example shows how to use the TensorFlow Lite Python interpreter when provided a TensorFlow Lite FlatBuffer file. The example also demonstrates how to run inference on random input data. Run -`help(tf.contrib.lite.Interpreter)` in the Python terminal to get detailed +`help(tf.lite.Interpreter)` in the Python terminal to get detailed documentation on the interpreter. ```python @@ -193,7 +198,7 @@ import numpy as np import tensorflow as tf # Load TFLite model and allocate tensors. -interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite") +interpreter = tf.lite.Interpreter(model_path="converted_model.tflite") interpreter.allocate_tensors() # Get input and output tensors. @@ -227,11 +232,11 @@ val = img + const out = tf.identity(val, name="out") with tf.Session() as sess: - converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) + converter = tf.lite.TFLiteConverter.from_session(sess, [img], [out]) tflite_model = converter.convert() # Load TFLite model and allocate tensors. -interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model) +interpreter = tf.lite.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() ``` @@ -254,5 +259,5 @@ identically to `TFLiteConverter`. ### Converting models prior to TensorFlow 1.9 To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.7 and TensorFlow -1.8, use the `toco_convert` function. Run `help(tf.contrib.lite.toco_convert)` +1.8, use the `toco_convert` function. Run `help(tf.lite.toco_convert)` to get details about accepted parameters. diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/lite/g3doc/custom_operators.md similarity index 98% rename from tensorflow/contrib/lite/g3doc/custom_operators.md rename to tensorflow/lite/g3doc/custom_operators.md index ee6150b60e8e85..4a22d6a67577cf 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/lite/g3doc/custom_operators.md @@ -103,7 +103,7 @@ operations instead of a single operator. pre-allocating the memory using temporary tensors. You may need to use OpData struct to reference the tensor indices in other functions. See example in the - [kernel for convolution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/kernels/conv.cc). + [kernel for convolution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/conv.cc). A sample code snippet is below ``` @@ -164,7 +164,7 @@ for node in frozen_graph_def.node: tf.TensorShape([10]), ]) node.attr['_output_quantized'].b = False -tflite_model = tf.contrib.lite.toco_convert( +tflite_model = tf.lite.toco_convert( frozen_graph_def,...) ``` diff --git a/tensorflow/contrib/lite/g3doc/demo_android.md b/tensorflow/lite/g3doc/demo_android.md similarity index 90% rename from tensorflow/contrib/lite/g3doc/demo_android.md rename to tensorflow/lite/g3doc/demo_android.md index c38b928684848b..772598d5cfd36a 100644 --- a/tensorflow/contrib/lite/g3doc/demo_android.md +++ b/tensorflow/lite/g3doc/demo_android.md @@ -2,7 +2,7 @@ # Android Demo App An example Android application using TensorFLow Lite is available -[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo). +[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo). The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. To run the demo, a device running Android 5.0 ( API 21) or higher is required. @@ -41,23 +41,23 @@ app: [Android Studio](https://developer.android.com/studio/index.html). * Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio settings). -* Import the `tensorflow/contrib/lite/java/demo` directory as a new +* Import the `tensorflow/lite/java/demo` directory as a new Android Studio project. * Install all the Gradle extensions it requests. Now you can build and run the demo app. -The build process downloads the quantized [Mobilenet TensorFlow Lite model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip), and unzips it into the assets directory: `tensorflow/contrib/lite/java/demo/app/src/main/assets/`. +The build process downloads the quantized [Mobilenet TensorFlow Lite model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip), and unzips it into the assets directory: `tensorflow/lite/java/demo/app/src/main/assets/`. Some additional details are available on the -[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). +[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo/README.md). ### Using other models To use a different model: * Download the floating point [Inception-v3 model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip). * Unzip and copy `inceptionv3_non_slim_2015.tflite` to the assets directory. -* Change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java)
+* Change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java)
from: `classifier = new ImageClassifierQuantizedMobileNet(getActivity());`
to: `classifier = new ImageClassifierFloatInception(getActivity());`. @@ -114,14 +114,14 @@ android_ndk_repository( ``` Some additional details are available on the -[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). +[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo/README.md). ### Build the source code To build the demo app, run `bazel`: ``` -bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo +bazel build --cxxopt=--std=c++11 //tensorflow/lite/java/demo/app/src/main:TfLiteCameraDemo ``` Caution: Because of an bazel bug, we only support building the Android demo app diff --git a/tensorflow/contrib/lite/g3doc/demo_ios.md b/tensorflow/lite/g3doc/demo_ios.md similarity index 95% rename from tensorflow/contrib/lite/g3doc/demo_ios.md rename to tensorflow/lite/g3doc/demo_ios.md index 7579ad84a049ec..fbf1dd63925911 100644 --- a/tensorflow/contrib/lite/g3doc/demo_ios.md +++ b/tensorflow/lite/g3doc/demo_ios.md @@ -38,11 +38,11 @@ instructions walk you through building and running the demo on an iOS device. 2. Download the model files used by the demo app (this is done from inside the cloned directory): - sh tensorflow/contrib/lite/examples/ios/download_models.sh + sh tensorflow/lite/examples/ios/download_models.sh 3. Install the pod to generate the workspace file: - cd tensorflow/contrib/lite/examples/ios/camera + cd tensorflow/lite/examples/ios/camera pod install If you have installed this pod before and that command doesn't work, try diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/lite/g3doc/devguide.md similarity index 93% rename from tensorflow/contrib/lite/g3doc/devguide.md rename to tensorflow/lite/g3doc/devguide.md index 0eed5160009c07..270cb8ce378a2b 100644 --- a/tensorflow/contrib/lite/g3doc/devguide.md +++ b/tensorflow/lite/g3doc/devguide.md @@ -35,7 +35,7 @@ by suggesting contextually relevant messages. The model is built specifically fo memory constrained devices, such as watches and phones, and has been successfully used in Smart Replies on Android Wear. Currently, this model is Android-specific. -These pre-trained models are [available for download](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md) +These pre-trained models are [available for download](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md) ### Re-train Inception-V3 or MobileNet for a custom data set @@ -63,7 +63,7 @@ framework. See to create .pb file for the custom model. TensorFlow Lite currently supports a subset of TensorFlow operators. Refer to the -[TensorFlow Lite & TensorFlow Compatibility Guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) +[TensorFlow Lite & TensorFlow Compatibility Guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/g3doc/tf_ops_compatibility.md) for supported operators and their usage. This set of operators will continue to grow in future Tensorflow Lite releases. @@ -151,7 +151,7 @@ inference in the `freeze_graph` step. It is also possible to use the Tensorflow Optimizing Converter with protobufs from either Python or from the command line (see the -[toco_from_protos.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py) +[toco_from_protos.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco/python/toco_from_protos.py) example). This allows you to integrate the conversion step into the model design workflow, ensuring the model is easily convertible to a mobile inference graph. For example: @@ -164,25 +164,25 @@ val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) out = tf.identity(val, name="out") with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) + tflite_model = tf.lite.toco_convert(sess.graph_def, [img], [out]) open("converteds_model.tflite", "wb").write(tflite_model) ``` For usage, see the Tensorflow Optimizing Converter -[command-line examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). +[command-line examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco/g3doc/cmdline_examples.md). Refer to the -[Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) +[Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help, and if that doesn't help, please [file an issue](https://github.com/tensorflow/tensorflow/issues). The [development repo](https://github.com/tensorflow/tensorflow) contains a tool to visualize TensorFlow Lite models after conversion. To build the -[visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tools/visualize.py) +[visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py) tool: ```sh -bazel run tensorflow/contrib/lite/tools:visualize -- model.tflite model_viz.html +bazel run tensorflow/lite/tools:visualize -- model.tflite model_viz.html ``` This generates an interactive HTML page listing subgraphs, operations, and a @@ -201,7 +201,7 @@ provides the ability to load a graph, set up inputs, and run the model to calculate outputs. The open source Android demo app uses the JNI interface and is available -[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). +[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo/app). You can also download a [prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). See the Android demo guide for details. @@ -212,7 +212,7 @@ installing TensorFlow on Android and setting up `bazel` and Android Studio. ### iOS To integrate a TensorFlow model in an iOS app, see the -[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) +[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/g3doc/ios.md) guide and iOS demo guide. #### Core ML support @@ -227,6 +227,6 @@ devices. To use the converter, refer to the ### Raspberry Pi Compile Tensorflow Lite for a Raspberry Pi by following the -[RPi build instructions](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/rpi.md) +[RPi build instructions](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/rpi.md) This compiles a static library file (`.a`) used to build your app. There are plans for Python bindings and a demo app. diff --git a/tensorflow/contrib/lite/g3doc/images/convert/sample_after.png b/tensorflow/lite/g3doc/images/convert/sample_after.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/convert/sample_after.png rename to tensorflow/lite/g3doc/images/convert/sample_after.png diff --git a/tensorflow/contrib/lite/g3doc/images/convert/sample_before.png b/tensorflow/lite/g3doc/images/convert/sample_before.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/convert/sample_before.png rename to tensorflow/lite/g3doc/images/convert/sample_before.png diff --git a/tensorflow/contrib/lite/g3doc/images/convert/workflow.svg b/tensorflow/lite/g3doc/images/convert/workflow.svg similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/convert/workflow.svg rename to tensorflow/lite/g3doc/images/convert/workflow.svg diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/lite/g3doc/images/landing-page/assistant_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png rename to tensorflow/lite/g3doc/images/landing-page/assistant_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png rename to tensorflow/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/lite/g3doc/images/landing-page/fishbrain_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png rename to tensorflow/lite/g3doc/images/landing-page/fishbrain_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/lite/g3doc/images/landing-page/fishbrain_logo_big.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png rename to tensorflow/lite/g3doc/images/landing-page/fishbrain_logo_big.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/lite/g3doc/images/landing-page/gboard_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png rename to tensorflow/lite/g3doc/images/landing-page/gboard_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/lite/g3doc/images/landing-page/gmail_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png rename to tensorflow/lite/g3doc/images/landing-page/gmail_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/lite/g3doc/images/landing-page/loseit_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png rename to tensorflow/lite/g3doc/images/landing-page/loseit_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/lite/g3doc/images/landing-page/loseit_logo_big.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png rename to tensorflow/lite/g3doc/images/landing-page/loseit_logo_big.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/lite/g3doc/images/landing-page/nest_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png rename to tensorflow/lite/g3doc/images/landing-page/nest_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/lite/g3doc/images/landing-page/photos_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png rename to tensorflow/lite/g3doc/images/landing-page/photos_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/lite/g3doc/images/landing-page/shazam_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png rename to tensorflow/lite/g3doc/images/landing-page/shazam_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/lite/g3doc/images/landing-page/vsco_logo.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png rename to tensorflow/lite/g3doc/images/landing-page/vsco_logo.png diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png b/tensorflow/lite/g3doc/images/performance/model_size_vs_accuracy.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png rename to tensorflow/lite/g3doc/images/performance/model_size_vs_accuracy.png diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png b/tensorflow/lite/g3doc/images/performance/model_size_vs_latency.png similarity index 100% rename from tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png rename to tensorflow/lite/g3doc/images/performance/model_size_vs_latency.png diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/lite/g3doc/ios.md similarity index 78% rename from tensorflow/contrib/lite/g3doc/ios.md rename to tensorflow/lite/g3doc/ios.md index 3b9fcca8117dc1..c195b6abf4f76f 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/lite/g3doc/ios.md @@ -41,24 +41,24 @@ brew link libtool Then you need to run a shell script to download the dependencies you need: ```bash -tensorflow/contrib/lite/tools/make/download_dependencies.sh +tensorflow/lite/tools/make/download_dependencies.sh ``` This will fetch copies of libraries and data from the web and install them in -`tensorflow/contrib/lite/downloads`. +`tensorflow/lite/downloads`. With all of the dependencies set up, you can now build the library for all five supported architectures on iOS: ```bash -tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh +tensorflow/lite/tools/make/build_ios_universal_lib.sh ``` -Under the hood this uses a makefile in `tensorflow/contrib/lite` to build the +Under the hood this uses a makefile in `tensorflow/lite` to build the different versions of the library, followed by a call to `lipo` to bundle them into a universal file containing armv7, armv7s, arm64, i386, and x86_64 architectures. The resulting library is in -`tensorflow/contrib/lite/tools/make/gen/lib/libtensorflow-lite.a`. +`tensorflow/lite/tools/make/gen/lib/libtensorflow-lite.a`. If you get an error such as `no such file or directory: 'x86_64'` when running `build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure @@ -68,19 +68,19 @@ a value is selected in the "Command Line Tools" dropdown. You'll need to update various settings in your app to link against TensorFlow Lite. You can view them in the example project at -`tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj` but here's a full +`tensorflow/lite/examples/ios/simple/simple.xcodeproj` but here's a full rundown: - You'll need to add the library at - `tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a` to your linking build - stage, and in Search Paths add `tensorflow/contrib/lite/gen/lib` to the + `tensorflow/lite/gen/lib/libtensorflow-lite.a` to your linking build + stage, and in Search Paths add `tensorflow/lite/gen/lib` to the Library Search Paths setting. - The _Header Search_ paths needs to contain: - the root folder of tensorflow, - - `tensorflow/contrib/lite/downloads` - - `tensorflow/contrib/lite/downloads/flatbuffers/include` + - `tensorflow/lite/downloads` + - `tensorflow/lite/downloads/flatbuffers/include` - C++11 support (or later) should be enabled by setting `C++ Language Dialect` to `GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`. diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/lite/g3doc/models.md similarity index 98% rename from tensorflow/contrib/lite/g3doc/models.md rename to tensorflow/lite/g3doc/models.md index 279764ce964e52..537e285490f905 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/lite/g3doc/models.md @@ -7,13 +7,13 @@ Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Ac ------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------: MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms -MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms -MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms MnasNet_1.0_192| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms +MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms + ^ Performance numbers are generated on Pixel-1 using single thread large BIG core. diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/lite/g3doc/ops_versioning.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/ops_versioning.md rename to tensorflow/lite/g3doc/ops_versioning.md diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/lite/g3doc/overview.md similarity index 97% rename from tensorflow/contrib/lite/g3doc/overview.md rename to tensorflow/lite/g3doc/overview.md index 9d035a69211d7c..2d747a9b59f734 100644 --- a/tensorflow/contrib/lite/g3doc/overview.md +++ b/tensorflow/lite/g3doc/overview.md @@ -12,7 +12,7 @@ optimizing the kernels for mobile apps, pre-fused activations, and quantized kernels that allow smaller and faster (fixed-point math) models. Most of our TensorFlow Lite documentation is [on -GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) +GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite) for the time being. ## What does TensorFlow Lite contain? @@ -118,7 +118,7 @@ TensorFlow Lite provides: to all first-party and third-party apps. Also see the complete list of - [TensorFlow Lite's supported models](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md), + [TensorFlow Lite's supported models](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md), including the model sizes, performance numbers, and downloadable model files. - Quantized versions of the MobileNet model, which runs faster than the @@ -136,7 +136,7 @@ We recommend you try out TensorFlow Lite with the pre-tested models indicated above. If you have an existing model, you will need to test whether your model is compatible with both the converter and the supported operator set. To test your model, see the -[documentation on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). +[documentation on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite). ### Retrain Inception-V3 or MobileNet for a custom data set @@ -198,5 +198,5 @@ possible performance for a particular model on a particular device. ## Next Steps -The TensorFlow Lite [GitHub repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). +The TensorFlow Lite [GitHub repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite). contains additional docs, code samples, and demo applications. diff --git a/tensorflow/contrib/lite/g3doc/performance/benchmarks.md b/tensorflow/lite/g3doc/performance/benchmarks.md similarity index 93% rename from tensorflow/contrib/lite/g3doc/performance/benchmarks.md rename to tensorflow/lite/g3doc/performance/benchmarks.md index 28cb6aba6ec61d..5a1e5586beecad 100644 --- a/tensorflow/contrib/lite/g3doc/performance/benchmarks.md +++ b/tensorflow/lite/g3doc/performance/benchmarks.md @@ -5,17 +5,17 @@ This document lists TensorFlow Lite performance benchmarks when running well known models on some Android and iOS devices. These performance benchmark numbers were generated with the -[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) -and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). +[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark) +and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios). # Android performance benchmarks For Android benchmarks, the CPU affinity is set to use big cores on the device to -reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). +reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#reducing-variance-between-runs-on-android)). It assumes that models were download and unzipped to the `/data/local/tmp/tflite_models` directory. The benchmark binary is built -using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android) +using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#on-android) and assumed in the `/data/local/tmp` directory. To run the benchmark: @@ -117,7 +117,7 @@ Pixel xl | 0c | # iOS benchmarks To run iOS benchmarks, the [benchmark -app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) +app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios) was modified to include the appropriate model and `benchmark_params.json` was modified to set `num_threads` to 1. diff --git a/tensorflow/contrib/lite/g3doc/performance/best_practices.md b/tensorflow/lite/g3doc/performance/best_practices.md similarity index 84% rename from tensorflow/contrib/lite/g3doc/performance/best_practices.md rename to tensorflow/lite/g3doc/performance/best_practices.md index 180e51e5f6664b..b76414cebe0d70 100644 --- a/tensorflow/contrib/lite/g3doc/performance/best_practices.md +++ b/tensorflow/lite/g3doc/performance/best_practices.md @@ -18,7 +18,7 @@ You can retrain the listed models on your own dataset by using transfer learning ## Profile your model -Once you have selected a candidate model that is right for your task, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. +Once you have selected a candidate model that is right for your task, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. ## Profile and optimize operators in the graph If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator. @@ -28,17 +28,17 @@ If a particular operator appears frequently in the model and based on profiling If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](model_optimization.md) for details about optimizing your model. ## Tweak the number of threads -Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads. Multi-threaded execution however comes at the cost of increased performance variability depending on what else is been executed concurrently. This is particularly the case for mobile apps. For example, isolated tests may show 2x speed up vs single-threaded but if another app is executing at the same time may result in worst performance than single-threaded. +Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/lite/interpreter.h#L337) threads. Multi-threaded execution however comes at the cost of increased performance variability depending on what else is been executed concurrently. This is particularly the case for mobile apps. For example, isolated tests may show 2x speed up vs single-threaded but if another app is executing at the same time may result in worst performance than single-threaded. ## Eliminate redundant copies -If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151). +If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151). ## Profile your application with platform specific tools Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform. ## Evaluate whether your model benefits from using hardware accelerators available on the device Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/) on Android. -You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance. +You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/lite/interpreter.h#L334) on the interpreter instance. ## Need more help The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file an issue on [GitHub](https://github.com/tensorflow/tensorflow/issues) with details of the issue. diff --git a/tensorflow/contrib/lite/g3doc/performance/model_optimization.md b/tensorflow/lite/g3doc/performance/model_optimization.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/performance/model_optimization.md rename to tensorflow/lite/g3doc/performance/model_optimization.md diff --git a/tensorflow/contrib/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md similarity index 91% rename from tensorflow/contrib/lite/g3doc/performance/post_training_quantization.md rename to tensorflow/lite/g3doc/performance/post_training_quantization.md index d95cab94aaea9a..cf4d70b2deb337 100644 --- a/tensorflow/contrib/lite/g3doc/performance/post_training_quantization.md +++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md @@ -7,7 +7,7 @@ is enabled as an option in [TensorFlow Lite model converter](../convert): ``` import tensorflow as tf -converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir) +converter = tf.lite.TocoConverter.from_saved_model(saved_model_dir) converter.post_training_quantize = True tflite_quantized_model = converter.convert() open("quantized_model.tflite", "wb").write(tflite_quantized_model) @@ -33,8 +33,8 @@ Hybrid ops are available for the most compute-intensive operators in a network: Since weights are quantized post-training, there could be an accuracy loss, particularly for smaller networks. Pre-trained fully quantized models are provided for specific networks in -the [TensorFlow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models){:.external}. It is important to check the accuracy of the quantized model to verify that any degradation -in accuracy is within acceptable limits. There is a tool to evaluate [TensorFlow Lite model accuracy](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tools/accuracy/README.md){:.external}. +the [TensorFlow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md#image-classification-quantized-models){:.external}. It is important to check the accuracy of the quantized model to verify that any degradation +in accuracy is within acceptable limits. There is a tool to evaluate [TensorFlow Lite model accuracy](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/accuracy/README.md){:.external}. If the accuracy drop is too high, consider using [quantization aware training](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/README.md){:.external}. diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/lite/g3doc/rpi.md similarity index 78% rename from tensorflow/contrib/lite/g3doc/rpi.md rename to tensorflow/lite/g3doc/rpi.md index 41a1892b6f179f..708d9e328cbdff 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/lite/g3doc/rpi.md @@ -23,18 +23,18 @@ Clone this Tensorflow repository, Run this script at the root of the repository > The Tensorflow repository is in `/tensorflow` if you are using `tensorflow/tensorflow:nightly-devel` docker image, just try it. ```bash -./tensorflow/contrib/lite/tools/make/download_dependencies.sh +./tensorflow/lite/tools/make/download_dependencies.sh ``` Note that you only need to do this once. You should then be able to compile: ```bash -./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh +./tensorflow/lite/tools/make/build_rpi_lib.sh ``` This should compile a static library in: -`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`. +`tensorflow/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`. ## Native compiling This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1). @@ -48,14 +48,14 @@ sudo apt-get install build-essential First, clone the TensorFlow repository. Run this at the root of the repository: ```bash -./tensorflow/contrib/lite/tools/make/download_dependencies.sh +./tensorflow/lite/tools/make/download_dependencies.sh ``` Note that you only need to do this once. You should then be able to compile: ```bash -./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh +./tensorflow/lite/tools/make/build_rpi_lib.sh ``` This should compile a static library in: -`tensorflow/contrib/lite/tools/make/gen/lib/rpi_armv7/libtensorflow-lite.a`. +`tensorflow/lite/tools/make/gen/lib/rpi_armv7/libtensorflow-lite.a`. diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/lite/g3doc/tf_ops_compatibility.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md rename to tensorflow/lite/g3doc/tf_ops_compatibility.md diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/lite/g3doc/tfmobile/android_build.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tfmobile/android_build.md rename to tensorflow/lite/g3doc/tfmobile/android_build.md diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/lite/g3doc/tfmobile/index.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tfmobile/index.md rename to tensorflow/lite/g3doc/tfmobile/index.md diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/lite/g3doc/tfmobile/ios_build.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md rename to tensorflow/lite/g3doc/tfmobile/ios_build.md diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/lite/g3doc/tfmobile/linking_libs.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md rename to tensorflow/lite/g3doc/tfmobile/linking_libs.md diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/lite/g3doc/tfmobile/optimizing.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md rename to tensorflow/lite/g3doc/tfmobile/optimizing.md diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/lite/g3doc/tfmobile/prepare_models.md similarity index 100% rename from tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md rename to tensorflow/lite/g3doc/tfmobile/prepare_models.md diff --git a/tensorflow/contrib/lite/graph_info.cc b/tensorflow/lite/graph_info.cc similarity index 99% rename from tensorflow/contrib/lite/graph_info.cc rename to tensorflow/lite/graph_info.cc index e60ed2c2463cb6..cdbe66a3a4fa90 100644 --- a/tensorflow/contrib/lite/graph_info.cc +++ b/tensorflow/lite/graph_info.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/lite/graph_info.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/lite/graph_info.h similarity index 93% rename from tensorflow/contrib/lite/graph_info.h rename to tensorflow/lite/graph_info.h index 8ee83827bb3fdf..ff7ce669aceccf 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/lite/graph_info.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ -#define TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#ifndef TENSORFLOW_LITE_GRAPH_INFO_H_ +#define TENSORFLOW_LITE_GRAPH_INFO_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -79,4 +79,4 @@ TfLiteStatus PartitionGraphIntoIndependentSubgraphs( } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#endif // TENSORFLOW_LITE_GRAPH_INFO_H_ diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/lite/graph_info_test.cc similarity index 99% rename from tensorflow/contrib/lite/graph_info_test.cc rename to tensorflow/lite/graph_info_test.cc index 89a8f36b416b5d..5ecc3774e13060 100644 --- a/tensorflow/contrib/lite/graph_info_test.cc +++ b/tensorflow/lite/graph_info_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/graph_info.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/graph_info.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/lite/interpreter.cc similarity index 98% rename from tensorflow/contrib/lite/interpreter.cc rename to tensorflow/lite/interpreter.cc index c72e7bf33ebbaa..bff7145de998f2 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -13,23 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/lite/interpreter.h" #include #include #include #include -#include "tensorflow/contrib/lite/arena_planner.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/graph_info.h" -#include "tensorflow/contrib/lite/memory_planner.h" -#include "tensorflow/contrib/lite/nnapi_delegate.h" -#include "tensorflow/contrib/lite/profiling/profiler.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/arena_planner.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/graph_info.h" +#include "tensorflow/lite/memory_planner.h" +#include "tensorflow/lite/nnapi_delegate.h" +#include "tensorflow/lite/profiling/profiler.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace { @@ -933,9 +933,8 @@ void Interpreter::SwitchToKernelContext() { SetForbiddenContextFunction(&context_.GetExecutionPlan); } -TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, - bool allow_dynamic_tensors) { - if (!allow_dynamic_tensors) { +TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { + if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) { int last_execution_plan_index_prepared; TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt( 0, &last_execution_plan_index_prepared)); @@ -971,7 +970,7 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, TF_LITE_ENSURE_OK(&context_, status); - if (!allow_dynamic_tensors) { + if (!(delegate->flags & kTfLiteDelegateFlagsAllowDynamicTensors)) { // Reset the state to force tensor/op reallocation. state_ = kStateUninvokable; TF_LITE_ENSURE_OK(&context_, AllocateTensors()); diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/lite/interpreter.h similarity index 97% rename from tensorflow/contrib/lite/interpreter.h rename to tensorflow/lite/interpreter.h index cbd042fa924000..7178b201ecc708 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -14,20 +14,20 @@ limitations under the License. ==============================================================================*/ // Main abstraction controlling the tflite interpreter. // See context.h for the API for defining operations (TfLiteRegistration). -#ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ -#define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#ifndef TENSORFLOW_LITE_INTERPRETER_H_ +#define TENSORFLOW_LITE_INTERPRETER_H_ #include #include #include #include -#include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/memory_planner.h" -#include "tensorflow/contrib/lite/profiling/profiler.h" -#include "tensorflow/contrib/lite/stderr_reporter.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/memory_planner.h" +#include "tensorflow/lite/profiling/profiler.h" +#include "tensorflow/lite/stderr_reporter.h" namespace tflite { @@ -360,8 +360,7 @@ class Interpreter { // parts of the graph themselves. After this is called, the graph may // contain new nodes that replace 1 more nodes. // WARNING: This is an experimental API and subject to change. - TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate, - bool allow_dynamic_tensors = false); + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); // Ensure the data in `tensor.data` is readable. In case delegate is used, // it might require to copy the data from delegate buffer to raw memory. @@ -581,13 +580,11 @@ class Interpreter { // Variant of the public ModifyGraphWithDelegate method that additionally // Assumes ownership of the provided delegate. // WARNING: This is an experimental API and subject to change. - TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate, - bool allow_dynamic_tensors = false) { + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate) { // Note that we retain ownership of the delegate even if graph modification // fails, as delegate use will be in an indeterminate state at that point. owned_delegates_.push_back(std::move(delegate)); - return ModifyGraphWithDelegate(owned_delegates_.back().get(), - allow_dynamic_tensors); + return ModifyGraphWithDelegate(owned_delegates_.back().get()); } // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra @@ -692,4 +689,4 @@ class Interpreter { }; } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#endif // TENSORFLOW_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc similarity index 98% rename from tensorflow/contrib/lite/interpreter_test.cc rename to tensorflow/lite/interpreter_test.cc index 6c71d5a8d7bb3e..3ac19fc87d1870 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/lite/interpreter.h" #include -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { @@ -1109,6 +1109,7 @@ class TestDelegate : public ::testing::Test { TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; }; // Store type-punned data SimpleDelegate structure. delegate_.data_ = reinterpret_cast(this); + delegate_.flags = kTfLiteDelegateFlagsNone; } static TfLiteRegistration FakeFusedRegistration() { @@ -1210,7 +1211,7 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) { interpreter_->Invoke(); delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate, true); + interpreter_->ModifyGraphWithDelegate(delegate); SimpleDelegate another_simple_delegate({0, 1, 2}); @@ -1268,6 +1269,7 @@ class TestDelegateWithDynamicTensors : public ::testing::Test { context, DelegateRegistration(), execution_plan, delegate); return kTfLiteOk; }; + delegate_.flags = kTfLiteDelegateFlagsNone; } static TfLiteRegistration DynamicCopyOpRegistration() { @@ -1296,7 +1298,7 @@ class TestDelegateWithDynamicTensors : public ::testing::Test { }; TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) { - interpreter_->ModifyGraphWithDelegate(&delegate_, false); + interpreter_->ModifyGraphWithDelegate(&delegate_); ASSERT_EQ(interpreter_->execution_plan().size(), 1); // The interpreter should not call delegate's `Prepare` when dynamic tensors @@ -1305,7 +1307,8 @@ TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) { } TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) { - interpreter_->ModifyGraphWithDelegate(&delegate_, true); + delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; + interpreter_->ModifyGraphWithDelegate(&delegate_); ASSERT_EQ(interpreter_->execution_plan().size(), 1); // The node should be replaced because dynamic tensors are allowed. Therefore @@ -1317,6 +1320,7 @@ TEST(TestDelegateOwnership, ProperlyDisposed) { struct TfLiteInterpreterOwnedDelegate : public TfLiteDelegate { TfLiteInterpreterOwnedDelegate(bool* destroyed, bool* prepared) : destroyed(destroyed), prepared(prepared) { + flags = kTfLiteDelegateFlagsNone; Prepare = [](TfLiteContext*, TfLiteDelegate* delegate) -> TfLiteStatus { *static_cast(delegate)->prepared = true; diff --git a/tensorflow/contrib/lite/java/AndroidManifest.xml b/tensorflow/lite/java/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/java/AndroidManifest.xml rename to tensorflow/lite/java/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/lite/java/BUILD similarity index 85% rename from tensorflow/contrib/lite/java/BUILD rename to tensorflow/lite/java/BUILD index c34f6ccfa05b32..cf759fa00c617d 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -2,14 +2,14 @@ # TensorFlow Lite Java API. package(default_visibility = [ - "//tensorflow/contrib/lite/java/ovic:__pkg__", + "//tensorflow/lite/java/ovic:__pkg__", ]) licenses(["notice"]) # Apache 2.0 load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") -load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("//tensorflow/lite:build_def.bzl", "tflite_jni_binary") +load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") JAVA_SRCS = glob([ "src/main/java/org/tensorflow/lite/*.java", @@ -18,7 +18,7 @@ JAVA_SRCS = glob([ # Building tensorflow-lite.aar including 4 variants of .so # To build an aar for release, run below command: # bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ -# tensorflow/contrib/lite/java:tensorflow-lite +# tensorflow/lite/java:tensorflow-lite aar_with_jni( name = "tensorflow-lite", android_library = ":tensorflowlite", @@ -90,7 +90,6 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], javacopts = JAVACOPTS, - tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ ":tensorflowlitelib", @@ -104,7 +103,6 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], javacopts = JAVACOPTS, - tags = ["no_oss"], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ ":tensorflowlitelib", @@ -122,11 +120,11 @@ java_test( "src/testdata/int32.bin", "src/testdata/int64.bin", "src/testdata/invalid_model.bin", + "src/testdata/quantized.bin", "src/testdata/uint8.bin", "src/testdata/with_custom_op.lite", ], javacopts = JAVACOPTS, - tags = ["no_oss"], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ ":tensorflowlitelib", @@ -142,16 +140,14 @@ java_test( srcs = ["src/test/java/org/tensorflow/lite/InterpreterTest.java"], data = [ "src/testdata/add.bin", - "src/testdata/mobilenet.tflite.bin", - "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", + "//tensorflow/lite:testdata/multi_add_flex.bin", ], javacopts = JAVACOPTS, - tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterTest", visibility = ["//visibility:private"], deps = [ ":tensorflowlitelib", - "//tensorflow/contrib/lite/java/src/test/native:libtensorflowlite_test_jni.so", + "//tensorflow/lite/java/src/test/native:libtensorflowlite_test_jni.so", "@com_google_truth", "@junit", ], @@ -162,10 +158,9 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"], data = [ - "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", + "//tensorflow/lite:testdata/multi_add_flex.bin", ], javacopts = JAVACOPTS, - tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterFlexTest", visibility = ["//visibility:private"], deps = [ @@ -183,7 +178,6 @@ java_test( "src/testdata/add.bin", ], javacopts = JAVACOPTS, - tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorTest", deps = [ ":tensorflowlitelib", @@ -215,7 +209,7 @@ cc_library( tflite_jni_binary( name = "libtensorflowlite_jni.so", deps = [ - "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/lite/java/src/main/native", ], ) @@ -223,8 +217,8 @@ tflite_jni_binary( tflite_jni_binary( name = "libtensorflowlite_flex_jni.so", deps = [ - "//tensorflow/contrib/lite/delegates/flex:delegate", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/main/native:init_tensorflow", + "//tensorflow/lite/delegates/flex:delegate", + "//tensorflow/lite/java/src/main/native", + "//tensorflow/lite/java/src/main/native:init_tensorflow", ], ) diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/lite/java/aar_with_jni.bzl similarity index 100% rename from tensorflow/contrib/lite/java/aar_with_jni.bzl rename to tensorflow/lite/java/aar_with_jni.bzl diff --git a/tensorflow/contrib/lite/java/build_aar_for_release.sh b/tensorflow/lite/java/build_aar_for_release.sh similarity index 98% rename from tensorflow/contrib/lite/java/build_aar_for_release.sh rename to tensorflow/lite/java/build_aar_for_release.sh index fbcb1e7db9a3f9..54be643fc7e0ae 100755 --- a/tensorflow/contrib/lite/java/build_aar_for_release.sh +++ b/tensorflow/lite/java/build_aar_for_release.sh @@ -22,7 +22,7 @@ trap "rm -rf $TMPDIR" EXIT VERSION=1.0 BUILDER=bazel -BASEDIR=tensorflow/contrib/lite +BASEDIR=tensorflow/lite CROSSTOOL="//external:android/crosstool" HOST_CROSSTOOL="@bazel_tools//tools/cpp:toolchain" diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/lite/java/demo/.gitignore similarity index 100% rename from tensorflow/contrib/lite/java/demo/.gitignore rename to tensorflow/lite/java/demo/.gitignore diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/lite/java/demo/README.md similarity index 93% rename from tensorflow/contrib/lite/java/demo/README.md rename to tensorflow/lite/java/demo/README.md index c04b2a61942430..b5bfe39ce7f6ab 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/lite/java/demo/README.md @@ -42,12 +42,12 @@ code to merge. ```shell bazel build -c opt --cxxopt='--std=c++11' \ - //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo + //tensorflow/lite/java/demo/app/src/main:TfLiteCameraDemo ``` 3. Install the demo on a [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): ```shell - adb install bazel-bin/tensorflow/contrib/lite/java/demo/app/src/main/TfLiteCameraDemo.apk + adb install bazel-bin/tensorflow/lite/java/demo/app/src/main/TfLiteCameraDemo.apk ``` diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/lite/java/demo/app/build.gradle similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/build.gradle rename to tensorflow/lite/java/demo/app/build.gradle diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/lite/java/demo/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml rename to tensorflow/lite/java/demo/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/lite/java/demo/app/src/main/BUILD similarity index 78% rename from tensorflow/contrib/lite/java/demo/app/src/main/BUILD rename to tensorflow/lite/java/demo/app/src/main/BUILD index 5ad738389eb8bc..df8a024a570fe0 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/lite/java/demo/app/src/main/BUILD @@ -9,7 +9,7 @@ android_binary( srcs = glob(["java/**/*.java"]), aapt_version = "aapt", assets = [ - "//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", + "//tensorflow/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", ], assets_dir = "", @@ -24,8 +24,8 @@ android_binary( # use the target in that case. tags = ["manual"], deps = [ - "//tensorflow/contrib/lite/java:tensorflowlite", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "//tensorflow/lite/java:tensorflowlite", + "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", "@androidsdk//com.android.support:support-v13-25.2.0", "@androidsdk//com.android.support:support-v4-25.2.0", ], diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/lite/java/demo/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD rename to tensorflow/lite/java/demo/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt b/tensorflow/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt rename to tensorflow/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt rename to tensorflow/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java rename to tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java rename to tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java rename to tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java rename to tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java rename to tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java rename to tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png b/tensorflow/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png rename to tensorflow/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml rename to tensorflow/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml rename to tensorflow/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/lite/java/demo/app/src/main/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml rename to tensorflow/lite/java/demo/app/src/main/res/layout/activity_camera.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml rename to tensorflow/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml rename to tensorflow/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml rename to tensorflow/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/lite/java/demo/app/src/main/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml rename to tensorflow/lite/java/demo/app/src/main/res/values-v11/template-styles.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml b/tensorflow/lite/java/demo/app/src/main/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-colors.xml rename to tensorflow/lite/java/demo/app/src/main/res/values-v21/base-colors.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml b/tensorflow/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml rename to tensorflow/lite/java/demo/app/src/main/res/values-v21/base-template-styles.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/lite/java/demo/app/src/main/res/values/base-strings.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml rename to tensorflow/lite/java/demo/app/src/main/res/values/base-strings.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/lite/java/demo/app/src/main/res/values/colors.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml rename to tensorflow/lite/java/demo/app/src/main/res/values/colors.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/lite/java/demo/app/src/main/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml rename to tensorflow/lite/java/demo/app/src/main/res/values/strings.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/lite/java/demo/app/src/main/res/values/styles.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml rename to tensorflow/lite/java/demo/app/src/main/res/values/styles.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml b/tensorflow/lite/java/demo/app/src/main/res/values/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-dimens.xml rename to tensorflow/lite/java/demo/app/src/main/res/values/template-dimens.xml diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml b/tensorflow/lite/java/demo/app/src/main/res/values/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/res/values/template-styles.xml rename to tensorflow/lite/java/demo/app/src/main/res/values/template-styles.xml diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/lite/java/demo/build.gradle similarity index 100% rename from tensorflow/contrib/lite/java/demo/build.gradle rename to tensorflow/lite/java/demo/build.gradle diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/lite/java/demo/gradle.properties similarity index 100% rename from tensorflow/contrib/lite/java/demo/gradle.properties rename to tensorflow/lite/java/demo/gradle.properties diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/lite/java/demo/gradle/wrapper/gradle-wrapper.jar similarity index 100% rename from tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar rename to tensorflow/lite/java/demo/gradle/wrapper/gradle-wrapper.jar diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/lite/java/demo/gradle/wrapper/gradle-wrapper.properties similarity index 100% rename from tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties rename to tensorflow/lite/java/demo/gradle/wrapper/gradle-wrapper.properties diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/lite/java/demo/gradlew similarity index 100% rename from tensorflow/contrib/lite/java/demo/gradlew rename to tensorflow/lite/java/demo/gradlew diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/lite/java/demo/gradlew.bat similarity index 100% rename from tensorflow/contrib/lite/java/demo/gradlew.bat rename to tensorflow/lite/java/demo/gradlew.bat diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/lite/java/demo/settings.gradle similarity index 100% rename from tensorflow/contrib/lite/java/demo/settings.gradle rename to tensorflow/lite/java/demo/settings.gradle diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/lite/java/ovic/BUILD similarity index 63% rename from tensorflow/contrib/lite/java/ovic/BUILD rename to tensorflow/lite/java/ovic/BUILD index 552468faf4121d..774320871eec9a 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/lite/java/ovic/BUILD @@ -15,15 +15,15 @@ java_test( size = "medium", srcs = ["src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], data = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, tags = ["no_oss"], test_class = "org.tensorflow.ovic.OvicClassifierTest", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + "//tensorflow/lite/java/ovic:ovicbenchmarkerlib_java", "@com_google_truth", "@junit", ], @@ -33,13 +33,13 @@ java_binary( name = "ovic_validator", srcs = ["src/main/java/org/tensorflow/ovic/OvicValidator.java"], data = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/lite/java/ovic/src/testdata:labels.txt", ], main_class = "org.tensorflow.ovic.OvicValidator", tags = ["no_oss"], deps = [ - "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", - "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib_java", + "//tensorflow/lite/java/ovic:ovicbenchmarkerlib_java", + "//tensorflow/lite/java/ovic:ovicdetectionbenchmarkerlib_java", ], ) @@ -51,11 +51,11 @@ android_library( "src/main/java/org/tensorflow/ovic/OvicClassifier.java", "src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java", ], - manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + manifest = "//tensorflow/lite/java:AndroidManifest.xml", tags = ["no_oss"], deps = [ - "//tensorflow/contrib/lite/java:tensorflowlite", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "//tensorflow/lite/java:tensorflowlite", + "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", "@org_checkerframework_qual", ], ) @@ -69,10 +69,10 @@ java_library( javacopts = JAVACOPTS, tags = ["no_oss"], deps = [ - "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", - "//tensorflow/contrib/lite/java:tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "//tensorflow/lite/java:libtensorflowlite_jni.so", + "//tensorflow/lite/java:tensorflowlite_java", + "//tensorflow/lite/java/src/main/native", + "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", "@org_checkerframework_qual", ], ) @@ -83,15 +83,15 @@ java_test( size = "medium", srcs = ["src/test/java/org/tensorflow/ovic/OvicDetectorTest.java"], data = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt", - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/lite/java/ovic/src/testdata:coco_labels.txt", + "//tensorflow/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, tags = ["no_oss"], test_class = "org.tensorflow.ovic.OvicDetectorTest", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib_java", + "//tensorflow/lite/java/ovic:ovicdetectionbenchmarkerlib_java", "@com_google_truth", "@junit", ], @@ -106,10 +106,10 @@ android_library( "src/main/java/org/tensorflow/ovic/OvicDetector.java", "src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java", ], - manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + manifest = "//tensorflow/lite/java:AndroidManifest.xml", deps = [ - "//tensorflow/contrib/lite/java:tensorflowlite", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "//tensorflow/lite/java:tensorflowlite", + "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", "@org_checkerframework_qual", ], ) @@ -123,10 +123,10 @@ java_library( ], javacopts = JAVACOPTS, deps = [ - "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", - "//tensorflow/contrib/lite/java:tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "//tensorflow/lite/java:libtensorflowlite_jni.so", + "//tensorflow/lite/java:tensorflowlite_java", + "//tensorflow/lite/java/src/main/native", + "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", "@org_checkerframework_qual", ], ) diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/lite/java/ovic/README.md similarity index 87% rename from tensorflow/contrib/lite/java/ovic/README.md rename to tensorflow/lite/java/ovic/README.md index 489ed3df407781..9e3ceb7e18e260 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/lite/java/ovic/README.md @@ -17,7 +17,7 @@ We are releasing an benchmarker Apk that would allow developers to measure laten The test data (models and images) should be downloaded automatically for you by Bazel. In case they are not, you can manually install them as below. -Note: all commands should be called from your tensorflow installation folder (under this folder you should find `tensorflow/contrib/lite`). +Note: all commands should be called from your tensorflow installation folder (under this folder you should find `tensorflow/lite`). * Download the [testdata package](https://storage.googleapis.com/download.tensorflow.org/data/ovic_2018_10_23.zip): @@ -29,7 +29,7 @@ curl -L https://storage.googleapis.com/download.tensorflow.org/data/ovic_2018_10 * Unzip the package into the testdata folder: ```sh -unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ +unzip -j /tmp/ovic.zip -d tensorflow/lite/java/ovic/src/testdata/ ``` ### Run tests @@ -37,9 +37,9 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ You can run test with Bazel as below. This helps to ensure that the installation is correct. ```sh -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicDetectorTest --cxxopt=-Wno-all --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/lite/java/ovic:OvicDetectorTest --cxxopt=-Wno-all --test_output=all ``` ### Test your submissions @@ -51,8 +51,8 @@ Once you have a submission that follows the instructions from the [competition s You can call the validator binary below to verify that your model fits the format requirements. This often helps you to catch size mismatches (e.g. output for classification should be [1, 1001] instead of [1,1,1,1001]). Let say the submission file is located at `/path/to/my_model.lite`, then call: ```sh -bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:ovic_validator --cxxopt=-Wno-all -bazel-bin/tensorflow/contrib/lite/java/ovic/ovic_validator /path/to/my_model.lite classify +bazel build --cxxopt=--std=c++11 //tensorflow/lite/java/ovic:ovic_validator --cxxopt=-Wno-all +bazel-bin/tensorflow/lite/java/ovic/ovic_validator /path/to/my_model.lite classify ``` Successful validation should print the following message to terminal: @@ -72,14 +72,14 @@ You can go a step further to verify that the model produces results as expected. * Move your submission to the testdata folder: ```sh -cp /path/to/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ +cp /path/to/my_model.lite tensorflow/lite/java/ovic/src/testdata/ ``` * Resize the test image to the resolutions that are expected by your submission: -The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224. +The test images can be found at `tensorflow/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224. -* Add your model and test image to the BUILD rule at `tensorflow/contrib/lite/java/ovic/src/testdata/BUILD`: +* Add your model and test image to the BUILD rule at `tensorflow/lite/java/ovic/src/testdata/BUILD`: ```JSON filegroup( @@ -113,7 +113,7 @@ We provide two ways to measure the on-device latency of your submission. The fir Make sure that you have followed instructions in [Test your submissions](#test-your-submissions) to add your model to the testdata folder and to the corresponding build rules. -Modify `tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`: +Modify `tensorflow/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`: * Add your model to the benchmarker apk by changing `MODEL_PATH` and `TEST_IMAGE_PATH` below to your submission and test image. @@ -140,8 +140,8 @@ Note: You'll need ROOT access to the phone to change processor affinity. * Build and install the app. ``` -bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/contrib/lite/java/ovic/demo/app:ovic_benchmarker_binary -adb install -r bazel-bin/tensorflow/contrib/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk +bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/lite/java/ovic/demo/app:ovic_benchmarker_binary +adb install -r bazel-bin/tensorflow/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk ``` Start the app and pick a task by clicking either the `CLF` button for classification or the `DET` button for detection. The button should turn bright green, signaling that the experiment is running. The benchmarking results will be displayed after about the `WALL_TIME` you specified above. For example: diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml b/tensorflow/lite/java/ovic/demo/app/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml rename to tensorflow/lite/java/ovic/demo/app/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/lite/java/ovic/demo/app/BUILD similarity index 62% rename from tensorflow/contrib/lite/java/ovic/demo/app/BUILD rename to tensorflow/lite/java/ovic/demo/app/BUILD index a30c707483ed82..b3548deaf53689 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/lite/java/ovic/demo/app/BUILD @@ -10,9 +10,9 @@ android_binary( ], aapt_version = "aapt", assets = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt", - "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/lite/java/ovic/src/testdata:coco_labels.txt", + "//tensorflow/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/lite/java/ovic/src/testdata:ovic_testdata", ], assets_dir = "", custom_package = "ovic.demo.app", @@ -24,9 +24,9 @@ android_binary( resource_files = glob(["res/**"]), tags = ["manual"], deps = [ - "//tensorflow/contrib/lite/java:tensorflowlite", - "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", - "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib", + "//tensorflow/lite/java:tensorflowlite", + "//tensorflow/lite/java/ovic:ovicbenchmarkerlib", + "//tensorflow/lite/java/ovic:ovicdetectionbenchmarkerlib", "@androidsdk//com.android.support:support-v13-25.2.0", "@androidsdk//com.android.support:support-v4-25.2.0", ], diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java rename to tensorflow/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/lite/java/ovic/demo/app/build.gradle similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/build.gradle rename to tensorflow/lite/java/ovic/demo/app/build.gradle diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png b/tensorflow/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png rename to tensorflow/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png b/tensorflow/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml b/tensorflow/lite/java/ovic/demo/app/res/drawable/start_button_color.xml similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml rename to tensorflow/lite/java/ovic/demo/app/res/drawable/start_button_color.xml diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/lite/java/ovic/demo/app/res/layout/activity_main.xml similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml rename to tensorflow/lite/java/ovic/demo/app/res/layout/activity_main.xml diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml b/tensorflow/lite/java/ovic/demo/app/res/values/dimens.xml similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml rename to tensorflow/lite/java/ovic/demo/app/res/values/dimens.xml diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml b/tensorflow/lite/java/ovic/demo/app/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml rename to tensorflow/lite/java/ovic/demo/app/res/values/strings.xml diff --git a/tensorflow/contrib/lite/java/ovic/demo/build.gradle b/tensorflow/lite/java/ovic/demo/build.gradle similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/build.gradle rename to tensorflow/lite/java/ovic/demo/build.gradle diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle.properties b/tensorflow/lite/java/ovic/demo/gradle.properties similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/gradle.properties rename to tensorflow/lite/java/ovic/demo/gradle.properties diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar rename to tensorflow/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties rename to tensorflow/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradlew b/tensorflow/lite/java/ovic/demo/gradlew similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/gradlew rename to tensorflow/lite/java/ovic/demo/gradlew diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat b/tensorflow/lite/java/ovic/demo/gradlew.bat similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/gradlew.bat rename to tensorflow/lite/java/ovic/demo/gradlew.bat diff --git a/tensorflow/contrib/lite/java/ovic/demo/settings.gradle b/tensorflow/lite/java/ovic/demo/settings.gradle similarity index 100% rename from tensorflow/contrib/lite/java/ovic/demo/settings.gradle rename to tensorflow/lite/java/ovic/demo/settings.gradle diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java similarity index 98% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java rename to tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java index 5756380abb751f..0a7aee043271b8 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java +++ b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java @@ -47,7 +47,7 @@ public static void main(String[] args) { final boolean isDetection = taskString.equals("detect"); // Label file for detection is never used, so the same label file is used for both tasks. final String labelPath = - "tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt"; + "tensorflow/lite/java/ovic/src/testdata/labels.txt"; try { MappedByteBuffer model = loadModelFile(modelFile); diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java similarity index 98% rename from tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java rename to tensorflow/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java index 99e874ca786a22..c309c5bd55114b 100644 --- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java +++ b/tensorflow/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -45,7 +45,7 @@ public final class OvicClassifierTest { private ByteBuffer lowResTestImage = null; private OvicClassificationResult testResult = null; private static final String LABELS_PATH = - "tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt"; + "tensorflow/lite/java/ovic/src/testdata/labels.txt"; private static final String QUANTIZED_MODEL_PATH = "external/tflite_ovic_testdata/quantized_model.lite"; private static final String LOW_RES_MODEL_PATH = diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java b/tensorflow/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java similarity index 98% rename from tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java rename to tensorflow/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java index 489d7a0f2b87d9..709f8fb5c32933 100644 --- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java +++ b/tensorflow/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java @@ -41,7 +41,7 @@ public final class OvicDetectorTest { private ByteBuffer testImage = null; private static final String LABELS_PATH = - "tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt"; + "tensorflow/lite/java/ovic/src/testdata/coco_labels.txt"; private static final String MODEL_PATH = "external/tflite_ovic_testdata/quantized_detect.lite"; private static final String TEST_IMAGE_PATH = diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD b/tensorflow/lite/java/ovic/src/testdata/BUILD similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/testdata/BUILD rename to tensorflow/lite/java/ovic/src/testdata/BUILD diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt b/tensorflow/lite/java/ovic/src/testdata/coco_labels.txt similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt rename to tensorflow/lite/java/ovic/src/testdata/coco_labels.txt diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt b/tensorflow/lite/java/ovic/src/testdata/labels.txt similarity index 100% rename from tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt rename to tensorflow/lite/java/ovic/src/testdata/labels.txt diff --git a/tensorflow/contrib/lite/java/proguard.flags b/tensorflow/lite/java/proguard.flags similarity index 100% rename from tensorflow/contrib/lite/java/proguard.flags rename to tensorflow/lite/java/proguard.flags diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Delegate.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Delegate.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/Delegate.java diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/package-info.java similarity index 100% rename from tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java rename to tensorflow/lite/java/src/main/java/org/tensorflow/lite/package-info.java diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD similarity index 90% rename from tensorflow/contrib/lite/java/src/main/native/BUILD rename to tensorflow/lite/java/src/main/native/BUILD index f91345f369fe11..2abba24345824c 100644 --- a/tensorflow/contrib/lite/java/src/main/native/BUILD +++ b/tensorflow/lite/java/src/main/native/BUILD @@ -4,7 +4,7 @@ package(default_visibility = ["//visibility:public"]) -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") licenses(["notice"]) # Apache 2.0 @@ -40,9 +40,9 @@ cc_library( "-ldl", ], deps = [ - "//tensorflow/contrib/lite:context", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/lite:context", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", ], alwayslink = 1, ) @@ -99,7 +99,7 @@ cc_library( "-ldl", ], deps = [ - "//tensorflow/contrib/lite/testing:init_tensorflow", + "//tensorflow/lite/testing:init_tensorflow", ], alwayslink = 1, ) @@ -115,7 +115,7 @@ cc_library( copts = tflite_copts(), deps = [ ":native_framework_only", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:builtin_ops", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc similarity index 95% rename from tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc rename to tensorflow/lite/java/src/main/native/builtin_ops_jni.cc index cce356370fa770..95bc0a4fa8d1d4 100644 --- a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc +++ b/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/lite/kernels/register.h" namespace tflite { diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/lite/java/src/main/native/exception_jni.cc similarity index 96% rename from tensorflow/contrib/lite/java/src/main/native/exception_jni.cc rename to tensorflow/lite/java/src/main/native/exception_jni.cc index 18d177f1a6d49a..5406c7197f0c6b 100644 --- a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc +++ b/tensorflow/lite/java/src/main/native/exception_jni.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/lite/java/src/main/native/exception_jni.h" const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; const char kIllegalStateException[] = "java/lang/IllegalStateException"; diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/lite/java/src/main/native/exception_jni.h similarity index 84% rename from tensorflow/contrib/lite/java/src/main/native/exception_jni.h rename to tensorflow/lite/java/src/main/native/exception_jni.h index 2a4bbdbeadcc64..ebd91e875b5b58 100644 --- a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h +++ b/tensorflow/lite/java/src/main/native/exception_jni.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_ -#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_ +#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_ +#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_ #include -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/lite/error_reporter.h" #ifdef __cplusplus extern "C" { @@ -47,4 +47,4 @@ class BufferErrorReporter : public tflite::ErrorReporter { #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_ +#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.cc b/tensorflow/lite/java/src/main/native/init_tensorflow_jni.cc similarity index 85% rename from tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.cc rename to tensorflow/lite/java/src/main/native/init_tensorflow_jni.cc index 74aa384df30334..1fa9d1f50e50d2 100644 --- a/tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.cc +++ b/tensorflow/lite/java/src/main/native/init_tensorflow_jni.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.h" -#include "tensorflow/contrib/lite/testing/init_tensorflow.h" +#include "tensorflow/lite/java/src/main/native/init_tensorflow_jni.h" +#include "tensorflow/lite/testing/init_tensorflow.h" JNIEXPORT void JNICALL Java_org_tensorflow_lite_TensorFlowLite_initTensorFlow( JNIEnv* env, jclass clazz) { diff --git a/tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.h b/tensorflow/lite/java/src/main/native/init_tensorflow_jni.h similarity index 81% rename from tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.h rename to tensorflow/lite/java/src/main/native/init_tensorflow_jni.h index 4689eb05fedcf8..1454d6d4633d4f 100644 --- a/tensorflow/contrib/lite/java/src/main/native/init_tensorflow_jni.h +++ b/tensorflow/lite/java/src/main/native/init_tensorflow_jni.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ -#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ +#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ +#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ #include @@ -33,4 +33,4 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_TensorFlowLite_initTensorFlow( } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ +#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc similarity index 98% rename from tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc rename to tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 83c6c9cb456f24..c7389c581100ac 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" +#include "tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" namespace { tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { @@ -465,9 +465,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate( TfLiteDelegate* delegate = convertLongToDelegate(env, delegate_handle); if (delegate == nullptr) return; - TfLiteStatus status = - interpreter->ModifyGraphWithDelegate(delegate, - /* allow_dynamic_tensors= */ true); + TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate); if (status != kTfLiteOk) { throwException(env, kIllegalArgumentException, "Internal error: Failed to apply delegate: %s", diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.h similarity index 93% rename from tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h rename to tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 5086bf8c2825fc..e184b8f1a783d5 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_ -#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_ +#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_ +#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_ #include #include #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" -#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/lite/java/src/main/native/tensor_jni.h" +#include "tensorflow/lite/model.h" namespace tflite { // This is to be provided at link-time by a library. @@ -245,4 +245,4 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_ +#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc similarity index 98% rename from tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc rename to tensorflow/lite/java/src/main/native/tensor_jni.cc index d3378f5f145dee..1d813d50da44de 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include "tensorflow/lite/java/src/main/native/tensor_jni.h" #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/java/src/main/native/exception_jni.h" namespace { diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/lite/java/src/main/native/tensor_jni.h similarity index 93% rename from tensorflow/contrib/lite/java/src/main/native/tensor_jni.h rename to tensorflow/lite/java/src/main/native/tensor_jni.h index c5e9690e9a04ba..ec0442e93f6f9d 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/lite/java/src/main/native/tensor_jni.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ -#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ +#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ +#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #ifdef __cplusplus extern "C" { @@ -109,4 +109,4 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ +#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/lite/java/src/main/native/tensorflow_lite_jni.cc similarity index 88% rename from tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc rename to tensorflow/lite/java/src/main/native/tensorflow_lite_jni.cc index 2e7f2f56921b87..2b8cf4201cea95 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensorflow_lite_jni.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h" +#include "tensorflow/lite/version.h" JNIEXPORT jstring JNICALL Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) { diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h similarity index 81% rename from tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h rename to tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h index 5e2a7ded1b495e..de3e703110c455 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h +++ b/tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ -#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ +#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ +#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ #include @@ -33,4 +33,4 @@ Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass); #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ +#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/lite/java/src/main/native/version_script.lds similarity index 100% rename from tensorflow/contrib/lite/java/src/main/native/version_script.lds rename to tensorflow/lite/java/src/main/native/version_script.lds diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java similarity index 100% rename from tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java rename to tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java similarity index 96% rename from tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java rename to tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java index 3b3d9f0e7fc070..b22399a4a47dcf 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java @@ -30,7 +30,7 @@ public final class InterpreterFlexTest { private static final File FLEX_MODEL_FILE = - new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + new File("tensorflow/lite/testdata/multi_add_flex.bin"); /** Smoke test validating that flex model loading works when the flex delegate is linked. */ @Test diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java new file mode 100644 index 00000000000000..b69bfa076e2268 --- /dev/null +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.File; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Interpreter} agains a MobileNet model. */ +@RunWith(JUnit4.class) +public final class InterpreterMobileNetTest { + + private static final File MOBILENET_MODEL_FILE = + new File("tensorflow/lite/java/src/testdata/mobilenet.tflite.bin"); + + @Test + public void testMobilenetRun() { + // Create a gray image. + float[][][][] img = new float[1][224][224][3]; + for (int i = 0; i < 224; ++i) { + for (int j = 0; j < 224; ++j) { + img[0][i][j][0] = 0.5f; + img[0][i][j][1] = 0.5f; + img[0][i][j][2] = 0.5f; + } + } + + // Allocate memory to receive the output values. + float[][] labels = new float[1][1001]; + + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + interpreter.run(img, labels); + assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3}); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001}); + interpreter.close(); + + assertThat(labels[0]) + .usingExactEquality() + .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java similarity index 91% rename from tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java rename to tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index d5e0347402a4e7..7e591b009d2f12 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -38,13 +38,10 @@ public final class InterpreterTest { private static final File MODEL_FILE = - new File("tensorflow/contrib/lite/java/src/testdata/add.bin"); - - private static final File MOBILENET_MODEL_FILE = - new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + new File("tensorflow/lite/java/src/testdata/add.bin"); private static final File FLEX_MODEL_FILE = - new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + new File("tensorflow/lite/testdata/multi_add_flex.bin"); @Test public void testInterpreter() throws Exception { @@ -214,32 +211,6 @@ public void testResizeInput() { } } - @Test - public void testMobilenetRun() { - // Create a gray image. - float[][][][] img = new float[1][224][224][3]; - for (int i = 0; i < 224; ++i) { - for (int j = 0; j < 224; ++j) { - img[0][i][j][0] = 0.5f; - img[0][i][j][1] = 0.5f; - img[0][i][j][2] = 0.5f; - } - } - - // Allocate memory to receive the output values. - float[][] labels = new float[1][1001]; - - Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); - interpreter.run(img, labels); - assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3}); - assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001}); - interpreter.close(); - - assertThat(labels[0]) - .usingExactEquality() - .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); - } - @Test public void testRunWithWrongInputType() { Interpreter interpreter = new Interpreter(MODEL_FILE); @@ -286,7 +257,7 @@ public void testRunWithWrongOutputType() { @Test public void testGetInputIndex() { - Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + Interpreter interpreter = new Interpreter(MODEL_FILE); try { interpreter.getInputIndex("WrongInputName"); fail(); @@ -303,7 +274,7 @@ public void testGetInputIndex() { @Test public void testGetOutputIndex() { - Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + Interpreter interpreter = new Interpreter(MODEL_FILE); try { interpreter.getOutputIndex("WrongOutputName"); fail(); @@ -312,9 +283,9 @@ public void testGetOutputIndex() { .hasMessageThat() .contains( "'WrongOutputName' is not a valid name for any output. Names of outputs and their" - + " indexes are {MobilenetV1/Predictions/Softmax=0}"); + + " indexes are {output=0}"); } - int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); + int index = interpreter.getOutputIndex("output"); assertThat(index).isEqualTo(0); } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java similarity index 97% rename from tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java rename to tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 270bd6703a101d..07d334c33b2337 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -32,28 +32,28 @@ public final class NativeInterpreterWrapperTest { private static final String FLOAT_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/add.bin"; + "tensorflow/lite/java/src/testdata/add.bin"; private static final String INT_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/int32.bin"; + "tensorflow/lite/java/src/testdata/int32.bin"; private static final String LONG_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/int64.bin"; + "tensorflow/lite/java/src/testdata/int64.bin"; private static final String BYTE_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + "tensorflow/lite/java/src/testdata/uint8.bin"; private static final String QUANTIZED_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/quantized.bin"; + "tensorflow/lite/java/src/testdata/quantized.bin"; private static final String INVALID_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; + "tensorflow/lite/java/src/testdata/invalid_model.bin"; private static final String MODEL_WITH_CUSTOM_OP_PATH = - "tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite"; + "tensorflow/lite/java/src/testdata/with_custom_op.lite"; private static final String NONEXISTING_MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin"; + "tensorflow/lite/java/src/testdata/nonexisting_model.bin"; @Test public void testConstructor() { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java similarity index 100% rename from tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java rename to tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java similarity index 99% rename from tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java rename to tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 56a38ea3e225e9..35ff4328b83e3b 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -33,7 +33,7 @@ public final class TensorTest { private static final String MODEL_PATH = - "tensorflow/contrib/lite/java/src/testdata/add.bin"; + "tensorflow/lite/java/src/testdata/add.bin"; private NativeInterpreterWrapper wrapper; private Tensor tensor; diff --git a/tensorflow/contrib/lite/java/src/test/native/BUILD b/tensorflow/lite/java/src/test/native/BUILD similarity index 64% rename from tensorflow/contrib/lite/java/src/test/native/BUILD rename to tensorflow/lite/java/src/test/native/BUILD index 17a10587dc3555..4d3e82b1ac1499 100644 --- a/tensorflow/contrib/lite/java/src/test/native/BUILD +++ b/tensorflow/lite/java/src/test/native/BUILD @@ -5,7 +5,7 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") +load("//tensorflow/lite:build_def.bzl", "tflite_jni_binary") cc_library( name = "native", @@ -17,11 +17,15 @@ cc_library( # For non-Android toolchains, generate jni.h and jni_md.h. "//tensorflow:android": [], "//conditions:default": [ - "//tensorflow/contrib/lite/java/src/main/native:jni.h", - "//tensorflow/contrib/lite/java/src/main/native:jni_md.h", + "//tensorflow/lite/java/src/main/native:jni.h", + "//tensorflow/lite/java/src/main/native:jni_md.h", ], }), - deps = ["//tensorflow/contrib/lite/c:c_api_internal"], + includes = select({ + "//tensorflow:android": [], + "//conditions:default": ["../../main/native/."], + }), + deps = ["//tensorflow/lite/c:c_api_internal"], ) tflite_jni_binary( diff --git a/tensorflow/contrib/lite/java/src/test/native/interpreter_test_jni.cc b/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc similarity index 77% rename from tensorflow/contrib/lite/java/src/test/native/interpreter_test_jni.cc rename to tensorflow/lite/java/src/test/native/interpreter_test_jni.cc index 6aad4973f778af..d83cb4cd305aa6 100644 --- a/tensorflow/contrib/lite/java/src/test/native/interpreter_test_jni.cc +++ b/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" #ifdef __cplusplus extern "C" { @@ -25,6 +25,8 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate( JNIEnv* env, jclass clazz) { // A simple op which outputs a vector of length 1 with the value [7]. static TfLiteRegistration registration = { + .init = nullptr, + .free = nullptr, .prepare = [](TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; @@ -38,10 +40,16 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate( TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; output->data.f[0] = 7.0f; return kTfLiteOk; - }}; + }, + .profiling_string = nullptr, + .builtin_code = 0, + .custom_name = "", + .version = 1, + }; // A simple delegate which replaces all ops with a single op that outputs a // vector of length 1 with the value [7]. static TfLiteDelegate delegate = { + .data_ = nullptr, .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { TfLiteIntArray* execution_plan; @@ -50,7 +58,12 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate( context->ReplaceSubgraphsWithDelegateKernels(context, registration, execution_plan, delegate); return kTfLiteOk; - }}; + }, + .CopyFromBufferHandle = nullptr, + .CopyToBufferHandle = nullptr, + .FreeBufferHandle = nullptr, + .flags = kTfLiteDelegateFlagsAllowDynamicTensors, + }; return reinterpret_cast(&delegate); } @@ -59,10 +72,14 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForInvalidDelegate( JNIEnv* env, jclass clazz) { // A simple delegate that fails during preparation. static TfLiteDelegate delegate = { - .Prepare = [](TfLiteContext* context, - TfLiteDelegate* delegate) -> TfLiteStatus { - return kTfLiteError; - }}; + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) + -> TfLiteStatus { return kTfLiteError; }, + .CopyFromBufferHandle = nullptr, + .CopyToBufferHandle = nullptr, + .FreeBufferHandle = nullptr, + .flags = kTfLiteDelegateFlagsNone, + }; return reinterpret_cast(&delegate); } diff --git a/tensorflow/contrib/lite/java/src/testdata/add.bin b/tensorflow/lite/java/src/testdata/add.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/add.bin rename to tensorflow/lite/java/src/testdata/add.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/float32.bin b/tensorflow/lite/java/src/testdata/float32.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/float32.bin rename to tensorflow/lite/java/src/testdata/float32.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/int32.bin b/tensorflow/lite/java/src/testdata/int32.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/int32.bin rename to tensorflow/lite/java/src/testdata/int32.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/int64.bin b/tensorflow/lite/java/src/testdata/int64.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/int64.bin rename to tensorflow/lite/java/src/testdata/int64.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin b/tensorflow/lite/java/src/testdata/invalid_model.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/invalid_model.bin rename to tensorflow/lite/java/src/testdata/invalid_model.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/quantized.bin b/tensorflow/lite/java/src/testdata/quantized.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/quantized.bin rename to tensorflow/lite/java/src/testdata/quantized.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/uint8.bin b/tensorflow/lite/java/src/testdata/uint8.bin similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/uint8.bin rename to tensorflow/lite/java/src/testdata/uint8.bin diff --git a/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite b/tensorflow/lite/java/src/testdata/with_custom_op.lite similarity index 100% rename from tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite rename to tensorflow/lite/java/src/testdata/with_custom_op.lite diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD similarity index 85% rename from tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD rename to tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index af1d99ef41e641..88641c86ed64e7 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -15,6 +15,6 @@ android_library( ], ), deps = [ - "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/lite/java:tensorflowlite_java", ], ) diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java similarity index 100% rename from tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java rename to tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD similarity index 68% rename from tensorflow/contrib/lite/kernels/BUILD rename to tensorflow/lite/kernels/BUILD index 363e7ed2e82197..010ba834661f7d 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -4,8 +4,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android") # Suppress warnings that are introduced by Eigen Tensor. @@ -31,8 +31,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -44,12 +44,12 @@ cc_library( hdrs = ["test_util.h"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", - "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels/internal:tensor_utils", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -65,9 +65,9 @@ cc_library( copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":op_macros", - "//tensorflow/contrib/lite:arena_planner", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/lite:arena_planner", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:optimized", ], ) @@ -82,7 +82,7 @@ cc_library( copts = tflite_copts(), deps = [ ":op_macros", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", "@gemmlowp", ], ) @@ -93,7 +93,7 @@ cc_library( "activation_functor.h", ], deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -113,9 +113,9 @@ cc_library( "kernel_util.h", ], deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels/internal:round", - "//tensorflow/contrib/lite/kernels/internal:types", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:round", + "//tensorflow/lite/kernels/internal:types", ], ) @@ -129,7 +129,7 @@ tf_cc_test( ], deps = [ ":kernel_util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -144,7 +144,7 @@ tf_cc_test( ], deps = [ ":test_util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -154,7 +154,7 @@ cc_library( srcs = [], hdrs = ["padding.h"], deps = [ - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -243,18 +243,18 @@ cc_library( ":lstm_eval", ":op_macros", ":padding", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:gemm_support", - "//tensorflow/contrib/lite/kernels/internal:audio_utils", - "//tensorflow/contrib/lite/kernels/internal:kernel_utils", - "//tensorflow/contrib/lite/kernels/internal:optimized", - "//tensorflow/contrib/lite/kernels/internal:optimized_base", - "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:reference_base", - "//tensorflow/contrib/lite/kernels/internal:tensor", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:gemm_support", + "//tensorflow/lite/kernels/internal:audio_utils", + "//tensorflow/lite/kernels/internal:kernel_utils", + "//tensorflow/lite/kernels/internal:optimized", + "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:tensor_utils", "@farmhash_archive//:farmhash", "@flatbuffers", ], @@ -266,9 +266,9 @@ cc_library( hdrs = ["lstm_eval.h"], deps = [ ":op_macros", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels/internal:kernel_utils", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:kernel_utils", + "//tensorflow/lite/kernels/internal:tensor_utils", ], ) @@ -278,9 +278,9 @@ cc_library( hdrs = ["register.h"], deps = [ ":builtin_op_kernels", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:util", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite:framework", + "//tensorflow/lite:util", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -294,8 +294,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -311,8 +311,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -328,8 +328,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -345,8 +345,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -362,8 +362,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -376,8 +376,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -389,8 +389,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -405,8 +405,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -421,8 +421,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -437,8 +437,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -453,10 +453,10 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", - "//tensorflow/contrib/lite/kernels/internal:reference", - "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/kernels/internal:reference", + "//tensorflow/lite/kernels/internal:reference_base", "@com_google_googletest//:gtest", ], ) @@ -471,8 +471,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -487,8 +487,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -503,8 +503,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -516,8 +516,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -529,8 +529,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], @@ -543,8 +543,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], @@ -560,8 +560,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], @@ -574,8 +574,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -590,9 +590,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -607,8 +607,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -623,8 +623,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -639,8 +639,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -655,8 +655,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -671,8 +671,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -684,8 +684,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -700,8 +700,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -716,8 +716,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -732,8 +732,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -748,8 +748,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -761,8 +761,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -777,8 +777,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -790,8 +790,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -806,9 +806,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -823,9 +823,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -837,8 +837,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -850,8 +850,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -863,8 +863,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -876,8 +876,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -889,8 +889,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -902,9 +902,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/kernels/internal:tensor_utils", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], @@ -917,8 +917,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -930,8 +930,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -943,9 +943,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", - "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/kernels/internal:reference_base", "@com_google_googletest//:gtest", ], ) @@ -960,9 +960,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", - "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/kernels/internal:reference_base", "@com_google_googletest//:gtest", ], ) @@ -974,8 +974,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -987,9 +987,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1001,8 +1001,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -1015,8 +1015,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1028,9 +1028,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1042,8 +1042,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1058,8 +1058,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1074,8 +1074,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1090,8 +1090,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1106,9 +1106,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1125,8 +1125,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1141,8 +1141,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1159,8 +1159,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1177,8 +1177,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1193,8 +1193,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], @@ -1210,9 +1210,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1227,9 +1227,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1244,9 +1244,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1261,9 +1261,9 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1275,9 +1275,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1289,8 +1289,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1302,9 +1302,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1316,9 +1316,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1330,9 +1330,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1344,9 +1344,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1358,9 +1358,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -1372,9 +1372,9 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/lite/kernels/activation_functor.h similarity index 86% rename from tensorflow/contrib/lite/kernels/activation_functor.h rename to tensorflow/lite/kernels/activation_functor.h index e075dc705410bb..60e93c185a9c07 100644 --- a/tensorflow/contrib/lite/kernels/activation_functor.h +++ b/tensorflow/lite/kernels/activation_functor.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#ifndef TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#define TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" namespace tflite { @@ -55,4 +55,4 @@ class ActivationFunctor { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#endif // TENSORFLOW_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/activations.cc rename to tensorflow/lite/kernels/activations.cc index 9aed4f09b82cc0..9c525d964077eb 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/lite/kernels/activations_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/activations_test.cc rename to tensorflow/lite/kernels/activations_test.cc index 9fa47e190a1dc7..fff4121dc0c265 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/lite/kernels/activations_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/lite/kernels/add.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/add.cc rename to tensorflow/lite/kernels/add.cc index b4393e8097f7f5..f4bfd8d3248178 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/lite/kernels/add.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/lite/kernels/add_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/add_test.cc rename to tensorflow/lite/kernels/add_test.cc index 261dd36ef0c517..1d33adf1999ecd 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/lite/kernels/add_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/lite/kernels/arg_min_max.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/arg_min_max.cc rename to tensorflow/lite/kernels/arg_min_max.cc index 531f4e1f1b057f..eea2de27f74af8 100644 --- a/tensorflow/contrib/lite/kernels/arg_min_max.cc +++ b/tensorflow/lite/kernels/arg_min_max.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc b/tensorflow/lite/kernels/arg_min_max_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/arg_min_max_test.cc rename to tensorflow/lite/kernels/arg_min_max_test.cc index c8181efc360c7f..dcdff74cc6f376 100644 --- a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc +++ b/tensorflow/lite/kernels/arg_min_max_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/lite/kernels/audio_spectrogram.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/audio_spectrogram.cc rename to tensorflow/lite/kernels/audio_spectrogram.cc index 0d2d5e775f82a2..5a995b31ca5e6f 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc +++ b/tensorflow/lite/kernels/audio_spectrogram.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/spectrogram.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" #include "flatbuffers/flexbuffers.h" // TF:flatbuffers diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/lite/kernels/audio_spectrogram_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc rename to tensorflow/lite/kernels/audio_spectrogram_test.cc index 7e4ff6fc16f26d..527af2767b1bfb 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc +++ b/tensorflow/lite/kernels/audio_spectrogram_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/lite/kernels/basic_rnn.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/basic_rnn.cc rename to tensorflow/lite/kernels/basic_rnn.cc index 7ec92ad401cc15..7c66ce1992f4c3 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/lite/kernels/basic_rnn.cc @@ -15,12 +15,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/lite/kernels/basic_rnn_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/basic_rnn_test.cc rename to tensorflow/lite/kernels/basic_rnn_test.cc index d1797354044c2f..240057d18a176d 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/lite/kernels/basic_rnn_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/lite/kernels/batch_to_space_nd.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/batch_to_space_nd.cc rename to tensorflow/lite/kernels/batch_to_space_nd.cc index fe2865dfb9a993..34fdf34f70c966 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/lite/kernels/batch_to_space_nd.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/lite/kernels/batch_to_space_nd_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc rename to tensorflow/lite/kernels/batch_to_space_nd_test.cc index 95b025c1b30cc6..a3e06d4c893270 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/lite/kernels/batch_to_space_nd_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc rename to tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index f8660fbaa237e8..2c345bba69e487 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -20,14 +20,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/lstm_eval.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/lstm_eval.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc rename to tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc index db98d6c49d42ac..b865322682a6db 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc rename to tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index 8b281b174e3e1a..5194c2463092ee 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc rename to tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc index d0d04428c9594d..5bad8e02c29608 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/lite/kernels/cast.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/cast.cc rename to tensorflow/lite/kernels/cast.cc index a7972140ac9f22..ac6c85b96921dc 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/lite/kernels/cast.cc @@ -15,13 +15,13 @@ limitations under the License. #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/lite/kernels/cast_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/cast_test.cc rename to tensorflow/lite/kernels/cast_test.cc index 954f998206563a..acdc331a7ea78e 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/lite/kernels/cast_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/lite/kernels/comparisons.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/comparisons.cc rename to tensorflow/lite/kernels/comparisons.cc index 3926af5b973947..a914449ae552e3 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/lite/kernels/comparisons.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { @@ -41,7 +41,7 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input1->type != kTfLiteString || input1->type != kTfLiteBool); // Currently only support tensors have the same type. - TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type); output->type = kTfLiteBool; bool requires_broadcast = !HaveSameShapes(input1, input2); diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/lite/kernels/comparisons_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/comparisons_test.cc rename to tensorflow/lite/kernels/comparisons_test.cc index 04c8bf2e3017bf..3c278c1f9e1097 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/lite/kernels/comparisons_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/concatenation.cc rename to tensorflow/lite/kernels/concatenation.cc index 7ad3399ffd3933..a8dd160c8dbb42 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/concatenation_test.cc rename to tensorflow/lite/kernels/concatenation_test.cc index 467ff6f7e149e3..422380a03eaf90 100644 --- a/tensorflow/contrib/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/conv.cc rename to tensorflow/lite/kernels/conv.cc index 6695282a924b13..0c14b9eb65692f 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -20,20 +20,20 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/kernels/padding.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/eigen_support.h" +#include "tensorflow/lite/kernels/gemm_support.h" +#include "tensorflow/lite/kernels/internal/optimized/cblas_conv.h" +#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/padding.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/conv_test.cc rename to tensorflow/lite/kernels/conv_test.cc index f7e6f083ed23f8..eebf9f9de46943 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/lite/kernels/conv_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/lite/kernels/depthwise_conv.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/depthwise_conv.cc rename to tensorflow/lite/kernels/depthwise_conv.cc index 19958844a1af87..3f4ae5087b267a 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/lite/kernels/depthwise_conv.cc @@ -19,17 +19,17 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/kernels/padding.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/padding.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/lite/kernels/depthwise_conv_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/depthwise_conv_test.cc rename to tensorflow/lite/kernels/depthwise_conv_test.cc index 4a33a0319d0dc3..d924e6f700781e 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/kernels/depthwise_conv_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/lite/kernels/dequantize.cc similarity index 90% rename from tensorflow/contrib/lite/kernels/dequantize.cc rename to tensorflow/lite/kernels/dequantize.cc index 59bf64e0afabc4..b2825bb9ea5a57 100644 --- a/tensorflow/contrib/lite/kernels/dequantize.cc +++ b/tensorflow/lite/kernels/dequantize.cc @@ -15,12 +15,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/dequantize_test.cc b/tensorflow/lite/kernels/dequantize_test.cc similarity index 90% rename from tensorflow/contrib/lite/kernels/dequantize_test.cc rename to tensorflow/lite/kernels/dequantize_test.cc index fcd74206177a0a..55265d93e527fd 100644 --- a/tensorflow/contrib/lite/kernels/dequantize_test.cc +++ b/tensorflow/lite/kernels/dequantize_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/detection_postprocess.cc rename to tensorflow/lite/kernels/detection_postprocess.cc index b24231eb06361b..84e2a0efb27c5e 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc +++ b/tensorflow/lite/kernels/detection_postprocess.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/lite/kernels/detection_postprocess_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/detection_postprocess_test.cc rename to tensorflow/lite/kernels/detection_postprocess_test.cc index 4d7ddd64bee480..d7ffaf1d82b542 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/kernels/detection_postprocess_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/lite/kernels/div.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/div.cc rename to tensorflow/lite/kernels/div.cc index 8d4bb5100664a3..fb40953123505a 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/lite/kernels/div.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/lite/kernels/div_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/div_test.cc rename to tensorflow/lite/kernels/div_test.cc index 97aa2fe04e2741..68a8855dd1346f 100644 --- a/tensorflow/contrib/lite/kernels/div_test.cc +++ b/tensorflow/lite/kernels/div_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/lite/kernels/eigen_support.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/eigen_support.cc rename to tensorflow/lite/kernels/eigen_support.cc index e542ad076528fa..44e0086ad88303 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/lite/kernels/eigen_support.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/eigen_support.h" +#include "tensorflow/lite/kernels/eigen_support.h" #include -#include "tensorflow/contrib/lite/arena_planner.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/arena_planner.h" +#include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace eigen_support { diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/lite/kernels/eigen_support.h similarity index 85% rename from tensorflow/contrib/lite/kernels/eigen_support.h rename to tensorflow/lite/kernels/eigen_support.h index feb1543f7be348..c24ae6896a7e97 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/lite/kernels/eigen_support.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ +#ifndef TENSORFLOW_LITE_KERNELS_EIGEN_SUPPORT_H_ +#define TENSORFLOW_LITE_KERNELS_EIGEN_SUPPORT_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace EigenForTFLite { struct ThreadPoolDevice; @@ -38,4 +38,4 @@ const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice( } // namespace eigen_support } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ +#endif // TENSORFLOW_LITE_KERNELS_EIGEN_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/lite/kernels/elementwise.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/elementwise.cc rename to tensorflow/lite/kernels/elementwise.cc index 8c624b320808d2..416a69eb0ed824 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/lite/kernels/elementwise.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/lite/kernels/elementwise_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/elementwise_test.cc rename to tensorflow/lite/kernels/elementwise_test.cc index 5dd89a0eaec13b..52df8dc3cca0b0 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/lite/kernels/elementwise_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/embedding_lookup.cc rename to tensorflow/lite/kernels/embedding_lookup.cc index 1d0c71ad48e36c..fad32607b4980c 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -37,10 +37,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/lite/kernels/embedding_lookup_sparse.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc rename to tensorflow/lite/kernels/embedding_lookup_sparse.cc index 0b076941ea2164..72bfe5b4f5d71f 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc +++ b/tensorflow/lite/kernels/embedding_lookup_sparse.cc @@ -65,11 +65,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/lite/kernels/embedding_lookup_sparse_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc rename to tensorflow/lite/kernels/embedding_lookup_sparse_test.cc index ef2b5422253ea8..0c555fdd7de61f 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc +++ b/tensorflow/lite/kernels/embedding_lookup_sparse_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/embedding_lookup_test.cc rename to tensorflow/lite/kernels/embedding_lookup_test.cc index 4a88d168c60203..8ea98a5f0dcbfb 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/lite/kernels/embedding_lookup_test.cc @@ -20,10 +20,10 @@ License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/lite/kernels/exp.cc similarity index 88% rename from tensorflow/contrib/lite/kernels/exp.cc rename to tensorflow/lite/kernels/exp.cc index 673e7be90a6d57..607b398ebd73f6 100644 --- a/tensorflow/contrib/lite/kernels/exp.cc +++ b/tensorflow/lite/kernels/exp.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/exp_test.cc b/tensorflow/lite/kernels/exp_test.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/exp_test.cc rename to tensorflow/lite/kernels/exp_test.cc index eed67369a1f30e..fa71fe351a421a 100644 --- a/tensorflow/contrib/lite/kernels/exp_test.cc +++ b/tensorflow/lite/kernels/exp_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/lite/kernels/expand_dims.cc similarity index 90% rename from tensorflow/contrib/lite/kernels/expand_dims.cc rename to tensorflow/lite/kernels/expand_dims.cc index fa1140b19c09dd..dd2479f34e6e8f 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims.cc +++ b/tensorflow/lite/kernels/expand_dims.cc @@ -15,12 +15,12 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { namespace builtin { diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/lite/kernels/expand_dims_test.cc similarity index 90% rename from tensorflow/contrib/lite/kernels/expand_dims_test.cc rename to tensorflow/lite/kernels/expand_dims_test.cc index a3bc1813dbc776..ea0c6c0fc830ec 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc +++ b/tensorflow/lite/kernels/expand_dims_test.cc @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/lite/kernels/fake_quant.cc similarity index 89% rename from tensorflow/contrib/lite/kernels/fake_quant.cc rename to tensorflow/lite/kernels/fake_quant.cc index b51af72fe66a69..9c799a7ec2247d 100644 --- a/tensorflow/contrib/lite/kernels/fake_quant.cc +++ b/tensorflow/lite/kernels/fake_quant.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/lite/kernels/fake_quant_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/fake_quant_test.cc rename to tensorflow/lite/kernels/fake_quant_test.cc index 11a02f7ed7474e..ce14703421e1cd 100644 --- a/tensorflow/contrib/lite/kernels/fake_quant_test.cc +++ b/tensorflow/lite/kernels/fake_quant_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/lite/kernels/floor.cc similarity index 88% rename from tensorflow/contrib/lite/kernels/floor.cc rename to tensorflow/lite/kernels/floor.cc index 59ff77f35b8d3f..aa117e3cacfc46 100644 --- a/tensorflow/contrib/lite/kernels/floor.cc +++ b/tensorflow/lite/kernels/floor.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/lite/kernels/floor_div.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/floor_div.cc rename to tensorflow/lite/kernels/floor_div.cc index 5d62cd27550f4f..9d404af5b0b5e9 100644 --- a/tensorflow/contrib/lite/kernels/floor_div.cc +++ b/tensorflow/lite/kernels/floor_div.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/floor_div_test.cc b/tensorflow/lite/kernels/floor_div_test.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/floor_div_test.cc rename to tensorflow/lite/kernels/floor_div_test.cc index eea69b61ac161e..8816260d9b45da 100644 --- a/tensorflow/contrib/lite/kernels/floor_div_test.cc +++ b/tensorflow/lite/kernels/floor_div_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/floor_mod.cc b/tensorflow/lite/kernels/floor_mod.cc similarity index 90% rename from tensorflow/contrib/lite/kernels/floor_mod.cc rename to tensorflow/lite/kernels/floor_mod.cc index b6bf054443b09a..beddac2174e372 100644 --- a/tensorflow/contrib/lite/kernels/floor_mod.cc +++ b/tensorflow/lite/kernels/floor_mod.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" // TODO(b/117523611): We should factor out a binary_op and put binary ops there. namespace tflite { @@ -82,8 +82,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteType type = input1->type; if (type != kTfLiteInt32 && type != kTfLiteFloat32) { - context->ReportError(context, - "Currently floor_mod only supports int32 and float."); + context->ReportError(context, "Type '%s' is not supported by floor_mod.", + TfLiteTypeGetName(type)); return kTfLiteError; } output->type = type; @@ -149,8 +149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output); } default: { - context->ReportError( - context, "Currently floor_mod only supports int32 and float."); + context->ReportError(context, "Type '%s' is not supported by floor_mod.", + TfLiteTypeGetName(input1->type)); return kTfLiteError; } } diff --git a/tensorflow/contrib/lite/kernels/floor_mod_test.cc b/tensorflow/lite/kernels/floor_mod_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/floor_mod_test.cc rename to tensorflow/lite/kernels/floor_mod_test.cc index c581a5393648db..9d75f5ce2e3ef8 100644 --- a/tensorflow/contrib/lite/kernels/floor_mod_test.cc +++ b/tensorflow/lite/kernels/floor_mod_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/floor_test.cc b/tensorflow/lite/kernels/floor_test.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/floor_test.cc rename to tensorflow/lite/kernels/floor_test.cc index b71e0400b6dc92..9bcbdba8a4f0b2 100644 --- a/tensorflow/contrib/lite/kernels/floor_test.cc +++ b/tensorflow/lite/kernels/floor_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/fully_connected.cc rename to tensorflow/lite/kernels/fully_connected.cc index cac556db33a6fa..63cca1cf5427f9 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -20,17 +20,17 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/gemm_support.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/fully_connected_test.cc rename to tensorflow/lite/kernels/fully_connected_test.cc index 08b43209466a1b..3351a30b123b12 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/lite/kernels/gather.cc similarity index 83% rename from tensorflow/contrib/lite/kernels/gather.cc rename to tensorflow/lite/kernels/gather.cc index b5afeb1a7bd552..195a6d2b81b6b7 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/lite/kernels/gather.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { @@ -42,8 +42,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); // Assign to output the input type. output->type = input->type; - // TODO(mgubin): Only default axis == 0 is supported. - TF_LITE_ENSURE_EQ(context, params->axis, 0); // Check conditions for different types. switch (input->type) { case kTfLiteFloat32: @@ -57,29 +55,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); } break; default: - context->ReportError( - context, "Only float32 and string types are supported, got %d", - input->type); + context->ReportError(context, "Type '%s' is not supported by gather.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } + + int axis = params->axis; + if (axis < 0) { + axis += NumDimensions(input); + } + TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); + const int num_dimensions = NumDimensions(input) + NumDimensions(positions) - 1; - TF_LITE_ENSURE(context, params->axis <= num_dimensions); TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); int output_index = 0; - for (int i = 0; i < params->axis; ++i) { + for (int i = 0; i < axis; ++i) { output_shape->data[output_index++] = input->dims->data[i]; } for (int i = 0; i < positions->dims->size; ++i) { output_shape->data[output_index++] = positions->dims->data[i]; } - for (int i = params->axis + 1; i < input->dims->size; ++i) { + for (int i = axis + 1; i < input->dims->size; ++i) { output_shape->data[output_index++] = input->dims->data[i]; } return context->ResizeTensor(context, output, output_shape); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = + reinterpret_cast(node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* positions = GetInput(context, node, kInputPositions); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -88,6 +93,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { { \ tflite::GatherParams op_params; \ op_params.input_rank = input_rank; \ + op_params.axis = params->axis; \ optimized_ops::Gather( \ op_params, GetTensorShape(input), GetTensorData(input), \ GetTensorShape(positions), GetTensorData(positions), \ diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/lite/kernels/gather_test.cc similarity index 80% rename from tensorflow/contrib/lite/kernels/gather_test.cc rename to tensorflow/lite/kernels/gather_test.cc index 1b48884e0907c6..58460f847fa062 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/lite/kernels/gather_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { @@ -27,12 +27,12 @@ using ::testing::ElementsAreArray; class GatherOpModel : public SingleOpModel { public: GatherOpModel(std::initializer_list input_shape, TensorType input_type, - std::initializer_list positions_shape) { + std::initializer_list positions_shape, int axis = 0) { input_ = AddInput(input_type); positions_ = AddInput(TensorType_INT32); output_ = AddOutput(input_type); SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions, - CreateGatherOptions(builder_, 0).Union()); + CreateGatherOptions(builder_, axis).Union()); BuildInterpreter({input_shape, positions_shape}); } @@ -123,6 +123,28 @@ TEST(FloatGatherOpTest, Slice) { EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.2, 0.8}))); } +TEST(FloatGatherOpTest, Axis1) { + const int axis = 1; + GatherOpModel m({1, 2, 3}, TensorType_FLOAT32, {2}, axis); + m.SetInputFloat({1, 2, 3, 4, 5, 6}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({4, 5, 6, 1, 2, 3}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3})); +} + +TEST(FloatGatherOpTest, LastAxis) { + const int axis = -1; + GatherOpModel m({1, 2, 3}, TensorType_FLOAT32, {2}, axis); + m.SetInputFloat({1, 2, 3, 4, 5, 6}); + m.SetPositions({2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({3, 1, 6, 4}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); +} + TEST(Uint8tGatherOpTest, Shuffle) { GatherOpModel m({2, 2}, TensorType_UINT8, {2}); m.SetInputUint8({133, 134, 14, 15}); diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/lite/kernels/gemm_support.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/gemm_support.cc rename to tensorflow/lite/kernels/gemm_support.cc index ed334af2da877e..cc224cb8840125 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/lite/kernels/gemm_support.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/lite/kernels/gemm_support.h" #include -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace gemm_support { diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/lite/kernels/gemm_support.h similarity index 89% rename from tensorflow/contrib/lite/kernels/gemm_support.h rename to tensorflow/lite/kernels/gemm_support.h index 43cd2b3055c5c3..1feb638952acb0 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/lite/kernels/gemm_support.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#ifndef TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_ +#define TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_ #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { namespace gemm_support { @@ -48,4 +48,4 @@ void DecrementUsageCounter(TfLiteContext* context); } // namespace gemm_support } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#endif // TENSORFLOW_LITE_KERNELS_GEMM_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/lite/kernels/hashtable_lookup.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/hashtable_lookup.cc rename to tensorflow/lite/kernels/hashtable_lookup.cc index c0b3c3c0c5beae..b6ae7a3d1a5479 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/lite/kernels/hashtable_lookup.cc @@ -39,11 +39,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/lite/kernels/hashtable_lookup_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc rename to tensorflow/lite/kernels/hashtable_lookup_test.cc index ba0ed5ce063926..d2ca76a206783f 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc +++ b/tensorflow/lite/kernels/hashtable_lookup_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD similarity index 89% rename from tensorflow/contrib/lite/kernels/internal/BUILD rename to tensorflow/lite/kernels/internal/BUILD index 9f8f224094d818..01c4c5d6b2cee9 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -4,8 +4,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") tflite_deps_intel = [ "@arm_neon_2_x86_sse", @@ -44,7 +44,7 @@ cc_library( "types.h", ], deps = [ - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/kernels:op_macros", "@com_google_absl//absl/base:core_headers", ], ) @@ -58,7 +58,7 @@ cc_library( "types.h", ], deps = [ - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/kernels:op_macros", "@com_google_absl//absl/base:core_headers", ], ) @@ -181,7 +181,7 @@ cc_library( ":tensor_utils", "//third_party/eigen3", "@gemmlowp", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -217,7 +217,7 @@ cc_library( ":round", "//third_party/eigen3", "@gemmlowp", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -247,7 +247,7 @@ cc_library( ":optimized_base", ":tensor", ":types", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", "//third_party/eigen3", ], ) @@ -281,7 +281,7 @@ cc_library( deps = [ ":round", ":types", - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/kernels:op_macros", ], ) @@ -326,8 +326,8 @@ cc_library( ":strided_slice_logic", ":types", "@gemmlowp", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:op_macros", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -360,8 +360,8 @@ cc_library( ":legacy_types", ":types", "@gemmlowp", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:op_macros", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -383,7 +383,7 @@ cc_library( ], deps = [ ":types", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -396,7 +396,7 @@ cc_library( ], deps = [ ":types", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -410,9 +410,9 @@ cc_library( ], deps = [ ":round", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:activation_functor", - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:activation_functor", + "//tensorflow/lite/kernels:op_macros", ], ) @@ -435,9 +435,9 @@ cc_library( ":cpu_check", ":round", ":types", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:activation_functor", - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:activation_functor", + "//tensorflow/lite/kernels:op_macros", "@arm_neon_2_x86_sse", "@gemmlowp", ], @@ -449,7 +449,7 @@ cc_library( hdrs = ["kernel_utils.h"], deps = [ ":tensor_utils", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", ], ) @@ -492,9 +492,9 @@ cc_library( copts = NEON_FLAGS_IF_APPLICABLE, deps = [ "@com_google_absl//absl/base:core_headers", - "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/lite/c:c_api_internal", "@arm_neon_2_x86_sse", - "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/lite/kernels:op_macros", "@gemmlowp", ] + select({ ":arm": [ @@ -548,7 +548,7 @@ cc_library( hdrs = ["test_util.h"], deps = [ ":types", - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", ], ) @@ -569,8 +569,8 @@ cc_test( ], deps = [ ":tensor_utils", - "//tensorflow/contrib/lite/c:c_api_internal", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest_main", ], ) @@ -593,6 +593,8 @@ cc_test( srcs = ["depthwiseconv_quantized_test.cc"], tags = [ "no_oss", + # TODO(b/119052685): Re-enable this test in TSAN. + "notsan", "tflite_not_portable_ios", ], deps = [ @@ -648,7 +650,7 @@ cc_test( ":quantization_util", ":reference_base", ":test_util", - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", "@com_google_googletest//:gtest_main", ], ) @@ -668,7 +670,7 @@ cc_test( ":quantization_util", ":reference_base", ":test_util", - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", "@com_google_googletest//:gtest_main", ], ) @@ -680,7 +682,7 @@ cc_test( deps = [ ":optimized_base", ":reference_base", - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/contrib/lite/kernels/internal/batch_to_space_nd_test.cc b/tensorflow/lite/kernels/internal/batch_to_space_nd_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/batch_to_space_nd_test.cc rename to tensorflow/lite/kernels/internal/batch_to_space_nd_test.cc index 5a2901ac8c2972..5fc2c93ba0e3a7 100644 --- a/tensorflow/contrib/lite/kernels/internal/batch_to_space_nd_test.cc +++ b/tensorflow/lite/kernels/internal/batch_to_space_nd_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/common.h rename to tensorflow/lite/kernels/internal/common.h index e67fee11b8d24d..fdb72037f84e4c 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_ #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK @@ -27,7 +27,7 @@ limitations under the License. #include #endif -#if defined __GNUC__ && defined __SSE4_1__ +#if defined __GNUC__ && defined __SSE4_1__ && !defined TF_LITE_DISABLE_X86_NEON #define USE_NEON #define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS @@ -46,7 +46,7 @@ limitations under the License. #endif #include "fixedpoint/fixedpoint.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -266,4 +266,4 @@ inline void NdArrayDescsForElementwiseBroadcast( } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/lite/kernels/internal/compatibility.h similarity index 92% rename from tensorflow/contrib/lite/kernels/internal/compatibility.h rename to tensorflow/lite/kernels/internal/compatibility.h index 7c176e0fa1c8e8..bfd021ac48df5b 100644 --- a/tensorflow/contrib/lite/kernels/internal/compatibility.h +++ b/tensorflow/lite/kernels/internal/compatibility.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ #include -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/op_macros.h" #ifndef TFLITE_DCHECK #define TFLITE_DCHECK(condition) (condition) ? (void)0 : TFLITE_ASSERT_FALSE @@ -107,4 +107,4 @@ using uint32 = std::uint32_t; #define TFLITE_DEPRECATED(message) #endif -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc rename to tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc index 41862a21a6ed5a..3602b9ffd84357 100644 --- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_float_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/kernels/internal/types.h" #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc rename to tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index fd0413715086c7..3682499d494cc4 100644 --- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/kernels/internal/types.h" #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/lite/kernels/internal/kernel_utils.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/kernel_utils.cc rename to tensorflow/lite/kernels/internal/kernel_utils.cc index 7875b23979e33c..0836a3b662d356 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/lite/kernels/internal/kernel_utils.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" namespace tflite { namespace kernel_utils { diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/lite/kernels/internal/kernel_utils.h similarity index 94% rename from tensorflow/contrib/lite/kernels/internal/kernel_utils.h rename to tensorflow/lite/kernels/internal/kernel_utils.h index 0387d753e5abf1..ebb91678fecd94 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/lite/kernels/internal/kernel_utils.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" namespace tflite { namespace kernel_utils { @@ -86,4 +86,4 @@ void RnnBatchStep( } // namespace kernel_utils } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/legacy_types.h b/tensorflow/lite/kernels/internal/legacy_types.h similarity index 74% rename from tensorflow/contrib/lite/kernels/internal/legacy_types.h rename to tensorflow/lite/kernels/internal/legacy_types.h index 2e4d3137f5c6ac..c19a1adb90f48c 100644 --- a/tensorflow/contrib/lite/kernels/internal/legacy_types.h +++ b/tensorflow/lite/kernels/internal/legacy_types.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_ -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -23,4 +23,4 @@ namespace tflite { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/lite/kernels/internal/log_quantized_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc rename to tensorflow/lite/kernels/internal/log_quantized_test.cc index 8963abb9afd9d5..8c39350ab1dd89 100644 --- a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/log_quantized_test.cc @@ -27,9 +27,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/string.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc rename to tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc index 2252ca1bcc2190..889a726f3a915f 100644 --- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc.cc b/tensorflow/lite/kernels/internal/mfcc.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/mfcc.cc rename to tensorflow/lite/kernels/internal/mfcc.cc index eafe0c7afee6fa..fddd4c46b6094a 100644 --- a/tensorflow/contrib/lite/kernels/internal/mfcc.cc +++ b/tensorflow/lite/kernels/internal/mfcc.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "tensorflow/contrib/lite/kernels/internal/mfcc.h" +#include "tensorflow/lite/kernels/internal/mfcc.h" namespace tflite { namespace internal { diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc.h b/tensorflow/lite/kernels/internal/mfcc.h similarity index 88% rename from tensorflow/contrib/lite/kernels/internal/mfcc.h rename to tensorflow/lite/kernels/internal/mfcc.h index d8500ecdcf38e5..8dae91efdeb542 100644 --- a/tensorflow/contrib/lite/kernels/internal/mfcc.h +++ b/tensorflow/lite/kernels/internal/mfcc.h @@ -15,13 +15,13 @@ limitations under the License. // Basic class for computing MFCCs from spectrogram slices. -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_H_ #include -#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" -#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" +#include "tensorflow/lite/kernels/internal/mfcc_dct.h" +#include "tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h" namespace tflite { namespace internal { @@ -75,4 +75,4 @@ class Mfcc { } // namespace internal } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc b/tensorflow/lite/kernels/internal/mfcc_dct.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc rename to tensorflow/lite/kernels/internal/mfcc_dct.cc index b0b7d181bdcf01..c249fdb020a3ac 100644 --- a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.cc +++ b/tensorflow/lite/kernels/internal/mfcc_dct.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" +#include "tensorflow/lite/kernels/internal/mfcc_dct.h" #include diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h b/tensorflow/lite/kernels/internal/mfcc_dct.h similarity index 86% rename from tensorflow/contrib/lite/kernels/internal/mfcc_dct.h rename to tensorflow/lite/kernels/internal/mfcc_dct.h index a53f5cbd9bb70c..f2947b506b2aed 100644 --- a/tensorflow/contrib/lite/kernels/internal/mfcc_dct.h +++ b/tensorflow/lite/kernels/internal/mfcc_dct.h @@ -15,8 +15,8 @@ limitations under the License. // Basic minimal DCT class for MFCC speech processing. -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ #include @@ -40,4 +40,4 @@ class MfccDct { } // namespace internal } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_DCT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc b/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc rename to tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc index c3deb33d91a47b..9748da39862edd 100644 --- a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.cc +++ b/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.cc @@ -28,7 +28,7 @@ limitations under the License. // channels may end up with no contributing FFT bins. The resulting mel // spectrum output will have some channels that are always zero. -#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" +#include "tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h" #include diff --git a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h b/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h similarity index 91% rename from tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h rename to tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h index c1db28243eea39..53d05bff5f45e4 100644 --- a/tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h +++ b/tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h @@ -15,8 +15,8 @@ limitations under the License. // Basic class for applying a mel-scale mapping to a power spectrum. -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ #include @@ -60,4 +60,4 @@ class MfccMelFilterbank { } // namespace internal } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_MFCC_MEL_FILTERBANK_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/lite/kernels/internal/optimized/cblas_conv.h similarity index 90% rename from tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h rename to tensorflow/lite/kernels/internal/optimized/cblas_conv.h index 2d96da65c33bd4..53772050503b2b 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/cblas_conv.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ // The Conv implementation based on CBLAS interface. This is only used on iOS // for now, utilizing Apple's Accelerate framework. @@ -22,11 +22,11 @@ limitations under the License. #if TFLITE_USE_APPLE_ACCELERATE_FOR_CONV #include #else -#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h" +#include "tensorflow/lite/kernels/internal/optimized/cblas_reference.h" #endif -#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" namespace tflite { namespace cblas_ops { @@ -106,4 +106,4 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, } // namespace cblas_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h b/tensorflow/lite/kernels/internal/optimized/cblas_reference.h similarity index 89% rename from tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h rename to tensorflow/lite/kernels/internal/optimized/cblas_reference.h index 6acc513805c939..fa07578612aaa5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h +++ b/tensorflow/lite/kernels/internal/optimized/cblas_reference.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" // The reference implementation for a small subset of CBLAS interface. // This is only used for testing CBLAS implementation, and should never be used @@ -66,4 +66,4 @@ void cblas_sgemm(const enum CBLAS_ORDER order, } // namespace cblas_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/lite/kernels/internal/optimized/cpu_check.h similarity index 89% rename from tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h rename to tensorflow/lite/kernels/internal/optimized/cpu_check.h index 934308ef291956..ac4ea7d6dae045 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h +++ b/tensorflow/lite/kernels/internal/optimized/cpu_check.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_ namespace tflite { @@ -58,4 +58,4 @@ inline bool TestCPUFeatureNeon() { return false; } : Portable##funcname(__VA_ARGS__) #endif -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h rename to tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h index 4e21805048b703..25b66d4b5537f5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { @@ -1068,4 +1068,4 @@ inline void DepthwiseConv( } // namespace optimized_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h rename to tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 3c8d447c6d88c9..5317cea8843923 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { @@ -1994,4 +1994,4 @@ inline void DepthwiseConv( } // namespace optimized_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h rename to tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index f437487fc2a512..3f2ed0b1f0eb3c 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { @@ -3446,4 +3446,4 @@ inline void DepthwiseConv3x3Filter( } // namespace optimized_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_3X3_FILTER_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h similarity index 95% rename from tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h rename to tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h index ce3cde76999c77..29e3f534a38d42 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h +++ b/tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -16,8 +16,8 @@ limitations under the License. // Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. // TODO(petewarden) - move this to a common location in Eigen itself. -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ #define EIGEN_USE_CUSTOM_THREAD_POOL #define EIGEN_USE_THREADS @@ -32,9 +32,9 @@ limitations under the License. #define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE #define Eigen EigenForTFLite #if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE) -#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h" +#include "tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h" #elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE) -#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h" +#include "tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h" #else #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #endif @@ -226,4 +226,4 @@ EIGEN_DEVICE_FUNC // clang-format on -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h similarity index 95% rename from tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h rename to tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h index 6443f425b7d643..f71ddbf3220db5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ #define EIGEN_USE_CUSTOM_THREAD_POOL #define EIGEN_USE_THREADS @@ -140,4 +140,4 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" #include "Eigen/src/Core/util/ReenableStupidWarnings.h" -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h similarity index 95% rename from tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h rename to tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h index d34708b8fd0c07..5e83b7b846e33b 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h +++ b/tensorflow/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -19,8 +19,8 @@ limitations under the License. // clang-format off -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ #include "Eigen/Core" @@ -164,4 +164,4 @@ typedef unsigned __int64 uint64_t; #include "Eigen/src/Core/util/ReenableStupidWarnings.h" -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h rename to tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index 4218be20a4a08f..5485d907c29399 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ #include #include -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { @@ -1869,4 +1869,4 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } // namespace optimized_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h similarity index 92% rename from tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h rename to tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h index 4139cf4eba98d1..12dfd1abb61972 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_ #include #include @@ -26,11 +26,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace multithreaded_ops { @@ -174,4 +174,4 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, } // namespace multithreaded_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc rename to tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 36c15dbc578930..cf40ebb241d013 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -15,12 +15,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h" +#include "tensorflow/lite/kernels/internal/round.h" #ifdef USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h similarity index 93% rename from tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h rename to tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index 630a6bbf297086..903f4c80139cd3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" +#include "tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h" namespace tflite { namespace tensor_utils { @@ -153,4 +153,4 @@ void MeanStddevNormalization(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h rename to tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 7d2f53fbe66319..6f7031b36d25e8 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_ #include #include @@ -29,13 +29,13 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" -#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace optimized_ops { @@ -5990,4 +5990,4 @@ inline void ResizeNearestNeighbor( #pragma GCC diagnostic pop #endif -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h similarity index 96% rename from tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h rename to tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h index f87760a6c3e4a5..8f52ef131dedf4 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict @@ -183,4 +183,4 @@ void PortableMeanStddevNormalization(const float* input_vector, } // namespace tensor_utils } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/lite/kernels/internal/quantization_util.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/quantization_util.cc rename to tensorflow/lite/kernels/internal/quantization_util.cc index 544ef16ce18a36..0279d2a9229e02 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/lite/kernels/internal/quantization_util.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/round.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/lite/kernels/internal/quantization_util.h similarity index 96% rename from tensorflow/contrib/lite/kernels/internal/quantization_util.h rename to tensorflow/lite/kernels/internal/quantization_util.h index d74a1bac97f86c..bf313f39cd8b40 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/lite/kernels/internal/quantization_util.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ #include #include #include -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -277,4 +277,4 @@ bool CheckedLog2(const float x, int* log2_result); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc rename to tensorflow/lite/kernels/internal/quantization_util_test.cc index 25ea72b886a06e..2f8f7713795bf0 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" #include #include diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h similarity index 90% rename from tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h rename to tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h index 11224270a4b17f..0cecb16b48c919 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h +++ b/tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace reference_ops { @@ -97,4 +97,4 @@ inline void DepthwiseConv( } // end namespace reference_ops } // end namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h similarity index 91% rename from tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h rename to tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h index eab28e6c84c77f..002444b6810925 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ #include #include "fixedpoint/fixedpoint.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace reference_ops { @@ -109,4 +109,4 @@ inline void DepthwiseConv( } // end namespace reference_ops } // end namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h b/tensorflow/lite/kernels/internal/reference/fully_connected.h similarity index 96% rename from tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h rename to tensorflow/lite/kernels/internal/reference/fully_connected.h index 3c7fd292567131..8495452220b8e7 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/fully_connected.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace reference_ops { @@ -323,4 +323,4 @@ inline void ShuffledFullyConnected( } // namespace reference_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h rename to tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h index c8b64cfd96798c..c92f28c79efed0 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ #include #include -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/legacy_types.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/legacy_types.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" namespace tflite { @@ -803,6 +803,7 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, const Dims<4>& output_dims) { tflite::GatherParams op_params; op_params.input_rank = input_rank; + op_params.axis = 4 - input_rank; Gather(op_params, DimsToShape(input_dims), input_data, DimsToShape(coords_dims), coords_data, DimsToShape(output_dims), @@ -2122,4 +2123,4 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } // namespace reference_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc rename to tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 70d25c4bd9357a..d692063a968dab 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/op_macros.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h rename to tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h index 714b1164ee2d84..a06ebc1600d4fe 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict @@ -265,4 +265,4 @@ void MeanStddevNormalization(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h rename to tensorflow/lite/kernels/internal/reference/reference_ops.h index 6bee1439a98c08..b1fefbef04c87e 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ #include #include @@ -26,13 +26,13 @@ limitations under the License. #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" -#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/kernels/internal/reference/softmax.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -2938,39 +2938,37 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data, template inline void Gather(const tflite::GatherParams& op_params, - const RuntimeShape& unextended_input_shape, - const T* input_data, const RuntimeShape& coords_shape, - const int32* coords_data, - const RuntimeShape& unextended_output_shape, - T* output_data) { - TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape input_shape = - RuntimeShape::ExtendedShape(4, unextended_input_shape); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); - - const int input_rank = op_params.input_rank; - const int gather_dimensions = output_shape.DimensionsCount(); - TFLITE_DCHECK_GE(input_shape.DimensionsCount(), gather_dimensions); - const int axis = gather_dimensions - input_rank; - TFLITE_DCHECK_LT(axis, gather_dimensions); + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& coords_shape, const int32* coords_data, + const RuntimeShape& output_shape, T* output_data) { + int axis = op_params.axis; + if (axis < 0) { + axis += input_shape.DimensionsCount(); + } TFLITE_DCHECK_GE(axis, 0); + TFLITE_DCHECK_LT(axis, input_shape.DimensionsCount()); + const int axis_size = input_shape.Dims(axis); const int coords_count = coords_shape.FlatSize(); - TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis)); - int64_t stride = 1; - for (int i = axis + 1; i < gather_dimensions; ++i) { - stride *= input_shape.Dims(i); + int outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= input_shape.Dims(i); } - T* out = output_data; - for (int i = 0; i < coords_count; ++i) { - TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis)); - const T* in = input_data + coords_data[i] * stride; - memcpy(out, in, sizeof(T) * stride); - out += stride; + int inner_size = 1; + for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) { + inner_size *= input_shape.Dims(i); + } + + for (int outer = 0; outer < outer_size; ++outer) { + for (int i = 0; i < coords_count; ++i) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], axis_size); + std::memcpy( + output_data + (outer * coords_count + i) * inner_size, + input_data + (outer * axis_size + coords_data[i]) * inner_size, + sizeof(T) * inner_size); + } } } @@ -4467,4 +4465,4 @@ inline void ResizeNearestNeighbor( } // namespace reference_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h b/tensorflow/lite/kernels/internal/reference/softmax.h similarity index 92% rename from tensorflow/contrib/lite/kernels/internal/reference/softmax.h rename to tensorflow/lite/kernels/internal/reference/softmax.h index 7d442961349e34..51de6b51aa5308 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h +++ b/tensorflow/lite/kernels/internal/reference/softmax.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ #include "fixedpoint/fixedpoint.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/round.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace reference_ops { @@ -176,4 +176,4 @@ inline void Softmax(const float* in, const int input_size, const int batch_size, } // namespace reference_ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc rename to tensorflow/lite/kernels/internal/resize_bilinear_test.cc index 15df31f75a69b9..1c5ac1992f0f64 100644 --- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc +++ b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/internal/resize_nearest_neighbor_test.cc rename to tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc index a26410e78c67f1..102ee04e6a89bd 100644 --- a/tensorflow/contrib/lite/kernels/internal/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/lite/kernels/internal/round.h similarity index 86% rename from tensorflow/contrib/lite/kernels/internal/round.h rename to tensorflow/lite/kernels/internal/round.h index f299d0bd8733dc..cb494bfd5374d9 100644 --- a/tensorflow/contrib/lite/kernels/internal/round.h +++ b/tensorflow/lite/kernels/internal/round.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_ROUND_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_ROUND_H_ #include @@ -36,4 +36,4 @@ inline T TfLiteRound(const T x) { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc rename to tensorflow/lite/kernels/internal/softmax_quantized_test.cc index 831fb3c24353b2..743ce0355c96fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/lite/kernels/internal/spectrogram.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/spectrogram.cc rename to tensorflow/lite/kernels/internal/spectrogram.cc index 20abcb725859d0..58769ad8cc7a06 100644 --- a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc +++ b/tensorflow/lite/kernels/internal/spectrogram.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h" +#include "tensorflow/lite/kernels/internal/spectrogram.h" #include #include diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.h b/tensorflow/lite/kernels/internal/spectrogram.h similarity index 95% rename from tensorflow/contrib/lite/kernels/internal/spectrogram.h rename to tensorflow/lite/kernels/internal/spectrogram.h index b77a68f7dfe6ed..b885b9d7d2d845 100644 --- a/tensorflow/contrib/lite/kernels/internal/spectrogram.h +++ b/tensorflow/lite/kernels/internal/spectrogram.h @@ -28,8 +28,8 @@ limitations under the License. // window = hann(window_length_samples, 'periodic'); // S = abs(spectrogram(audio, window, overlap_samples)).^2; -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ #include #include @@ -107,4 +107,4 @@ class Spectrogram { } // namespace internal } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_SPECTROGRAM_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h similarity index 94% rename from tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h rename to tensorflow/lite/kernels/internal/strided_slice_logic.h index af5db1064c1b7b..e7fd5ca9319556 100644 --- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ #include #include -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace strided_slice { @@ -195,4 +195,4 @@ inline tflite::StridedSliceParams BuildStridedSliceParams( } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/lite/kernels/internal/tensor.h similarity index 91% rename from tensorflow/contrib/lite/kernels/internal/tensor.h rename to tensorflow/lite/kernels/internal/tensor.h index 689cea03e75875..b806753d886132 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/lite/kernels/internal/tensor.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_ #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -111,4 +111,4 @@ class VectorOfQuantizedTensors : public VectorOfTensors { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/lite/kernels/internal/tensor_ctypes.h similarity index 89% rename from tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h rename to tensorflow/lite/kernels/internal/tensor_ctypes.h index 9f5b33d2175351..d24dca9bfbbee7 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h +++ b/tensorflow/lite/kernels/internal/tensor_ctypes.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -99,4 +99,4 @@ inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/lite/kernels/internal/tensor_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/internal/tensor_test.cc rename to tensorflow/lite/kernels/internal/tensor_test.cc index 2ed73ba82d6473..7bfe280d6f883c 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor.h" #include #include diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/lite/kernels/internal/tensor_utils.cc similarity index 74% rename from tensorflow/contrib/lite/kernels/internal/tensor_utils.cc rename to tensorflow/lite/kernels/internal/tensor_utils.cc index f4181b18a8f46f..701e5a66aa1bac 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/common.h" #ifndef USE_NEON #if defined(__ARM_NEON__) || defined(__ARM_NEON) @@ -22,7 +22,7 @@ limitations under the License. #endif // USE_NEON #ifdef USE_NEON -#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h" #else -#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h" #endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h similarity index 96% rename from tensorflow/contrib/lite/kernels/internal/tensor_utils.h rename to tensorflow/lite/kernels/internal/tensor_utils.h index b0fe5adf65de83..71ae69522f9a45 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/lite/kernels/internal/tensor_utils.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict @@ -165,4 +165,4 @@ void MeanStddevNormalization(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc rename to tensorflow/lite/kernels/internal/tensor_utils_test.cc index 6458af714b8c71..29866d066406e5 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/test_util.h" namespace tflite { namespace tensor_utils { diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/lite/kernels/internal/test_util.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/test_util.cc rename to tensorflow/lite/kernels/internal/test_util.cc index 390d57a0108870..4462775ddbdd25 100644 --- a/tensorflow/contrib/lite/kernels/internal/test_util.cc +++ b/tensorflow/lite/kernels/internal/test_util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/lite/kernels/internal/test_util.h" #include #include diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/lite/kernels/internal/test_util.h similarity index 93% rename from tensorflow/contrib/lite/kernels/internal/test_util.h rename to tensorflow/lite/kernels/internal/test_util.h index e4a383bedfc034..766a627c99e03b 100644 --- a/tensorflow/contrib/lite/kernels/internal/test_util.h +++ b/tensorflow/lite/kernels/internal/test_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ #include #include @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -100,4 +100,4 @@ void FillRandomSkyscraper(std::vector* vec, int depth, } } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h similarity index 98% rename from tensorflow/contrib/lite/kernels/internal/types.h rename to tensorflow/lite/kernels/internal/types.h index 694c1797da27e9..04b95ddc63d7a6 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/lite/kernels/internal/types.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_ #include #include -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { @@ -86,7 +86,7 @@ enum class FullyConnectedWeightsFormat : uint8 { // bytes before using them in signed arithmetic, see this file for more // explanations on the 'signed int8 trick' in matrix multiplication kernels: // - // tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc + // tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc // kShuffled4x16Int8, }; @@ -1033,4 +1033,4 @@ inline void GetActivationParams(const P& params, float* min, float* max) { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/kernel_util.cc rename to tensorflow/lite/kernels/kernel_util.cc index 503ef284591912..e39890e3320eb4 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/kernel_util.h" #include #include #include -#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/lite/kernels/internal/round.h" namespace tflite { @@ -117,6 +117,10 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, int64_t dims1 = NumDimensions(input1); int64_t dims2 = NumDimensions(input2); int64_t out_dims = std::max(dims1, dims2); + if (NumElements(input1) == 0) { + *output_shape = TfLiteIntArrayCopy(input1->dims); + return kTfLiteOk; + } std::unique_ptr shape( TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree); for (int i = 0; i < out_dims; ++i) { diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h similarity index 95% rename from tensorflow/contrib/lite/kernels/kernel_util.h rename to tensorflow/lite/kernels/kernel_util.h index e9a5fd7a4052cd..3cc00588d63fed 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#ifndef TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_ +#define TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -135,4 +135,4 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, TfLiteIntArray** output_shape); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util_test.cc b/tensorflow/lite/kernels/kernel_util_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/kernel_util_test.cc rename to tensorflow/lite/kernels/kernel_util_test.cc index bf6f249acc85ee..70eb1836589109 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util_test.cc +++ b/tensorflow/lite/kernels/kernel_util_test.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/kernel_util.h" #include #include -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/lite/kernels/l2norm.cc similarity index 90% rename from tensorflow/contrib/lite/kernels/l2norm.cc rename to tensorflow/lite/kernels/l2norm.cc index e02d7df9ef1a38..19a4824e9398de 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/lite/kernels/l2norm.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/lite/kernels/l2norm_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/l2norm_test.cc rename to tensorflow/lite/kernels/l2norm_test.cc index 070ed60040997f..50108a5a264c36 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/lite/kernels/l2norm_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/lite/kernels/layer_norm_lstm.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/layer_norm_lstm.cc rename to tensorflow/lite/kernels/layer_norm_lstm.cc index 48dd03e7ae7e2a..5b0046a7b31c9c 100644 --- a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc +++ b/tensorflow/lite/kernels/layer_norm_lstm.cc @@ -17,9 +17,9 @@ limitations under the License. // deviation to the activation of the LSTM layers. Please see // https://arxiv.org/abs/1607.06450 for details. #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/lite/kernels/layer_norm_lstm_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc rename to tensorflow/lite/kernels/layer_norm_lstm_test.cc index 1535f750f94e72..e89bce50c311eb 100644 --- a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc +++ b/tensorflow/lite/kernels/layer_norm_lstm_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/lite/kernels/local_response_norm.cc similarity index 89% rename from tensorflow/contrib/lite/kernels/local_response_norm.cc rename to tensorflow/lite/kernels/local_response_norm.cc index 334d2a2788d10f..5cbf5d9eae700f 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm.cc +++ b/tensorflow/lite/kernels/local_response_norm.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/lite/kernels/local_response_norm_test.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/local_response_norm_test.cc rename to tensorflow/lite/kernels/local_response_norm_test.cc index d75ce258a04c82..bd644e07f46562 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc +++ b/tensorflow/lite/kernels/local_response_norm_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/lite/kernels/log_softmax_test.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/log_softmax_test.cc rename to tensorflow/lite/kernels/log_softmax_test.cc index 1acc966cdc947c..fb126295e6afdf 100644 --- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc +++ b/tensorflow/lite/kernels/log_softmax_test.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/lite/kernels/logical.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/logical.cc rename to tensorflow/lite/kernels/logical.cc index f770cb35d1b9ff..582bcff64a882e 100644 --- a/tensorflow/contrib/lite/kernels/logical.cc +++ b/tensorflow/lite/kernels/logical.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/lite/kernels/logical_test.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/logical_test.cc rename to tensorflow/lite/kernels/logical_test.cc index 206cbde98fa48e..b31616452717b1 100644 --- a/tensorflow/contrib/lite/kernels/logical_test.cc +++ b/tensorflow/lite/kernels/logical_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/lite/kernels/lsh_projection.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/lsh_projection.cc rename to tensorflow/lite/kernels/lsh_projection.cc index 9fa1c5f1002d89..f68ff4d634a7c9 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/lite/kernels/lsh_projection.cc @@ -59,10 +59,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/lite/kernels/lsh_projection_test.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/lsh_projection_test.cc rename to tensorflow/lite/kernels/lsh_projection_test.cc index 414d728dfc1530..cb2724a6ccebd9 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc +++ b/tensorflow/lite/kernels/lsh_projection_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/lstm.cc rename to tensorflow/lite/kernels/lstm.cc index 3666122e941f55..b57e2883b05232 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -20,17 +20,17 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/lstm_eval.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/gemm_support.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/lstm_eval.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/lstm_eval.cc rename to tensorflow/lite/kernels/lstm_eval.cc index f2ba7b46d9b053..f179ecb195e4dd 100644 --- a/tensorflow/contrib/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/lstm_eval.h" +#include "tensorflow/lite/kernels/lstm_eval.h" #include -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h similarity index 93% rename from tensorflow/contrib/lite/kernels/lstm_eval.h rename to tensorflow/lite/kernels/lstm_eval.h index 8d8b97aead63f2..c8a4d284f3c431 100644 --- a/tensorflow/contrib/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_ +#ifndef TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_ +#define TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { namespace ops { @@ -78,4 +78,4 @@ TfLiteStatus EvalHybrid( } // namespace builtin } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_ +#endif // TENSORFLOW_LITE_KERNELS_LSTM_EVAL_H_ diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/lstm_test.cc rename to tensorflow/lite/kernels/lstm_test.cc index f8947db7242174..03ad2e899d29b1 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/lite/kernels/maximum_minimum.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/maximum_minimum.cc rename to tensorflow/lite/kernels/maximum_minimum.cc index 7cb01465eef45c..3bcaabf675eba4 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc +++ b/tensorflow/lite/kernels/maximum_minimum.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc b/tensorflow/lite/kernels/maximum_minimum_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/maximum_minimum_test.cc rename to tensorflow/lite/kernels/maximum_minimum_test.cc index fd4d5367c5a636..acb74e09d3fb47 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc +++ b/tensorflow/lite/kernels/maximum_minimum_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/lite/kernels/mfcc.cc similarity index 89% rename from tensorflow/contrib/lite/kernels/mfcc.cc rename to tensorflow/lite/kernels/mfcc.cc index 5153ce5634c33e..f5b0212728e02b 100644 --- a/tensorflow/contrib/lite/kernels/mfcc.cc +++ b/tensorflow/lite/kernels/mfcc.cc @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/internal/mfcc.h" +#include "tensorflow/lite/kernels/internal/mfcc.h" #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" -#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/mfcc_dct.h" +#include "tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/lite/kernels/mfcc_test.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/mfcc_test.cc rename to tensorflow/lite/kernels/mfcc_test.cc index fe692232227966..ade5bf53d11f7d 100644 --- a/tensorflow/contrib/lite/kernels/mfcc_test.cc +++ b/tensorflow/lite/kernels/mfcc_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/lite/kernels/mul.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/mul.cc rename to tensorflow/lite/kernels/mul.cc index e0aac8a84244dd..b405dee47ef01e 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/lite/kernels/mul.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/lite/kernels/mul_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/mul_test.cc rename to tensorflow/lite/kernels/mul_test.cc index 0f9c0c2eee51e7..200cc26dadc352 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/lite/kernels/mul_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/lite/kernels/neg.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/neg.cc rename to tensorflow/lite/kernels/neg.cc index 0ddd0644f5a1cc..e9a1aa23254230 100644 --- a/tensorflow/contrib/lite/kernels/neg.cc +++ b/tensorflow/lite/kernels/neg.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/lite/kernels/neg_test.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/neg_test.cc rename to tensorflow/lite/kernels/neg_test.cc index 3d3594c60bbe16..d461ede3c480e2 100644 --- a/tensorflow/contrib/lite/kernels/neg_test.cc +++ b/tensorflow/lite/kernels/neg_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/lite/kernels/one_hot.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/one_hot.cc rename to tensorflow/lite/kernels/one_hot.cc index 910aed6f142dc9..2ac12fe9308f38 100644 --- a/tensorflow/contrib/lite/kernels/one_hot.cc +++ b/tensorflow/lite/kernels/one_hot.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/lite/kernels/one_hot_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/one_hot_test.cc rename to tensorflow/lite/kernels/one_hot_test.cc index 6b604ec7a7f86b..85438327e7e3a3 100644 --- a/tensorflow/contrib/lite/kernels/one_hot_test.cc +++ b/tensorflow/lite/kernels/one_hot_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/lite/kernels/op_macros.h similarity index 88% rename from tensorflow/contrib/lite/kernels/op_macros.h rename to tensorflow/lite/kernels/op_macros.h index d0c5630649c98f..1a54a378b03d64 100644 --- a/tensorflow/contrib/lite/kernels/op_macros.h +++ b/tensorflow/lite/kernels/op_macros.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_ +#ifndef TENSORFLOW_LITE_KERNELS_OP_MACROS_H_ +#define TENSORFLOW_LITE_KERNELS_OP_MACROS_H_ // If we're on a platform without standard IO functions, fall back to a // non-portable function. #ifdef TF_LITE_MCU_DEBUG_LOG -#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #define DEBUG_LOG(x) \ do { \ @@ -67,4 +67,4 @@ inline void InfiniteLoop() { if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \ } while (0) -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_ +#endif // TENSORFLOW_LITE_KERNELS_OP_MACROS_H_ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/lite/kernels/optional_tensor_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/optional_tensor_test.cc rename to tensorflow/lite/kernels/optional_tensor_test.cc index 90a915bb023b2b..a09f86015894c4 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/lite/kernels/optional_tensor_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/lite/kernels/pack.cc similarity index 87% rename from tensorflow/contrib/lite/kernels/pack.cc rename to tensorflow/lite/kernels/pack.cc index c368582ef76c72..24fabccde09fab 100644 --- a/tensorflow/contrib/lite/kernels/pack.cc +++ b/tensorflow/lite/kernels/pack.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { @@ -41,9 +41,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, data->axis >= 0); if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 && input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16) { - context->ReportError(context, - "Currently pack only supports " - "float32/uint8/int16/int32."); + context->ReportError(context, "Type '%s' is not supported by pack.", + TfLiteTypeGetName(input0->type)); return kTfLiteError; } // Make sure all inputs have the same shape and type. @@ -112,9 +111,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } default: { - context->ReportError(context, - "Currently pack only supports " - "float32/uint8/int32."); + context->ReportError(context, "Type '%s' is not supported by pack.", + TfLiteTypeGetName(output->type)); return kTfLiteError; } } diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/lite/kernels/pack_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/pack_test.cc rename to tensorflow/lite/kernels/pack_test.cc index c70dbd2764b615..a47e9ff40d079b 100644 --- a/tensorflow/contrib/lite/kernels/pack_test.cc +++ b/tensorflow/lite/kernels/pack_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/pad.cc rename to tensorflow/lite/kernels/pad.cc index a3cefe99ccb404..8e6ed6e741f782 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/lite/kernels/pad.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/lite/kernels/pad_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/pad_test.cc rename to tensorflow/lite/kernels/pad_test.cc index 9c55767f69ce3d..415a285c707e6a 100644 --- a/tensorflow/contrib/lite/kernels/pad_test.cc +++ b/tensorflow/lite/kernels/pad_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { @@ -58,19 +58,6 @@ class PadOpModel : public SingleOpModel { int constant_values_; }; -namespace { - -// Returns the corresponding TensorType given the type T. -template -TensorType GetTensorType() { - if (std::is_same::value) return TensorType_FLOAT32; - if (std::is_same::value) return TensorType_INT32; - if (std::is_same::value) return TensorType_UINT8; - return TensorType_MIN; // default value -} - -} // namespace - // Tests case where paddings is a const tensor. Type T is the dtype. template class PadV2OpConstModel : public PadOpModel { diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/lite/kernels/padding.h similarity index 90% rename from tensorflow/contrib/lite/kernels/padding.h rename to tensorflow/lite/kernels/padding.h index 42b6b45d3bfc4a..30aa4f1bd330e2 100644 --- a/tensorflow/contrib/lite/kernels/padding.h +++ b/tensorflow/lite/kernels/padding.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#ifndef TENSORFLOW_LITE_KERNELS_PADDING_H_ +#define TENSORFLOW_LITE_KERNELS_PADDING_H_ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/builtin_op_data.h" namespace tflite { @@ -55,4 +55,4 @@ inline TfLitePaddingValues ComputePaddingHeightWidth( } } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#endif // TENSORFLOW_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/lite/kernels/pooling.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/pooling.cc rename to tensorflow/lite/kernels/pooling.cc index 6451142391599e..694a36ffbcf3c8 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/lite/kernels/pooling.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/kernels/padding.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/padding.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/lite/kernels/pooling_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/pooling_test.cc rename to tensorflow/lite/kernels/pooling_test.cc index 01c91b2ba905e2..80eef02509009c 100644 --- a/tensorflow/contrib/lite/kernels/pooling_test.cc +++ b/tensorflow/lite/kernels/pooling_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/lite/kernels/pow.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/pow.cc rename to tensorflow/lite/kernels/pow.cc index 1e96cc80b16791..9f84e1cc5e6d83 100644 --- a/tensorflow/contrib/lite/kernels/pow.cc +++ b/tensorflow/lite/kernels/pow.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/lite/kernels/pow_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/pow_test.cc rename to tensorflow/lite/kernels/pow_test.cc index 74b3aef5bd39d8..60d674e9779f0a 100644 --- a/tensorflow/contrib/lite/kernels/pow_test.cc +++ b/tensorflow/lite/kernels/pow_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/range.cc b/tensorflow/lite/kernels/range.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/range.cc rename to tensorflow/lite/kernels/range.cc index 241b40574522b8..eefe5db1ecee7a 100644 --- a/tensorflow/contrib/lite/kernels/range.cc +++ b/tensorflow/lite/kernels/range.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/range_test.cc b/tensorflow/lite/kernels/range_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/range_test.cc rename to tensorflow/lite/kernels/range_test.cc index 8faa092fb24d90..e1d4aaba433050 100644 --- a/tensorflow/contrib/lite/kernels/range_test.cc +++ b/tensorflow/lite/kernels/range_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/reduce.cc rename to tensorflow/lite/kernels/reduce.cc index 4732a37a65a37a..ed2d475f6d7d38 100644 --- a/tensorflow/contrib/lite/kernels/reduce.cc +++ b/tensorflow/lite/kernels/reduce.cc @@ -15,13 +15,13 @@ limitations under the License. #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/reduce_test.cc rename to tensorflow/lite/kernels/reduce_test.cc index fb2ec58ab28ebc..c1526bddb719e7 100644 --- a/tensorflow/contrib/lite/kernels/reduce_test.cc +++ b/tensorflow/lite/kernels/reduce_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/register.cc rename to tensorflow/lite/kernels/register.cc index 92eddf7d79df1b..c6834537671034 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/lite/kernels/register.h similarity index 78% rename from tensorflow/contrib/lite/kernels/register.h rename to tensorflow/lite/kernels/register.h index 61856ab9de6563..eb5ce667d4c9eb 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/lite/kernels/register.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#ifndef TENSORFLOW_LITE_KERNELS_REGISTER_H_ +#define TENSORFLOW_LITE_KERNELS_REGISTER_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/mutable_op_resolver.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/mutable_op_resolver.h" namespace tflite { namespace ops { @@ -37,4 +37,4 @@ class BuiltinOpResolver : public MutableOpResolver { } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#endif // TENSORFLOW_LITE_KERNELS_REGISTER_H_ diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/lite/kernels/relu1.cc similarity index 92% rename from tensorflow/contrib/lite/kernels/relu1.cc rename to tensorflow/lite/kernels/relu1.cc index abafee2d576fd7..5a55631405b6b3 100644 --- a/tensorflow/contrib/lite/kernels/relu1.cc +++ b/tensorflow/lite/kernels/relu1.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/lite/kernels/relu1_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/relu1_test.cc rename to tensorflow/lite/kernels/relu1_test.cc index b1d25a9f504fe2..f52d10b0b7f32a 100644 --- a/tensorflow/contrib/lite/kernels/relu1_test.cc +++ b/tensorflow/lite/kernels/relu1_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc similarity index 60% rename from tensorflow/contrib/lite/kernels/reshape.cc rename to tensorflow/lite/kernels/reshape.cc index f41147b2d6433a..d040c677019a8e 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/lite/kernels/reshape.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { @@ -59,7 +59,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node, return context->ResizeTensor(context, output, output_shape); } -TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context, +TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); @@ -67,30 +67,14 @@ TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context, for (int i = 0; i < output_shape->size; ++i) { output_shape->data[i] = shape->data.i32[i]; } - return ResizeOutput(context, node, output_shape); + + return output_shape; } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context, + TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - - // Attempt to use shape tensor if it exists. - if (NumInputs(node) == 2) { - const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); - // Check if the shape tensor is valid. - if (shape->dims->size == 1 && shape->type == kTfLiteInt32) { - // Set the output tensor as dynamic if the shape isn't constnat. - if (!IsConstantTensor(shape)) { - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - SetTensorToDynamic(output); - return kTfLiteOk; - } - // Shape is constant. Resize now. - return ResizeOutputWithShapeTensor(context, node); - } - } // The function is returned above this line if the shape tensor is usable. // Now fallback to the shape parameter in `TfLiteReshapeParams`. int num_dimensions = params->num_dimensions; @@ -104,15 +88,67 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { for (int i = 0; i < num_dimensions; ++i) { output_shape->data[i] = params->shape[i]; } - return ResizeOutput(context, node, output_shape); + + return output_shape; +} + +// Check if the shape tensor is valid. Shapes should be int32 vectors. +bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + return (shape->dims->size == 1 && shape->type == kTfLiteInt32); +} + +TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) { + if (NumInputs(node) == 2 && ShapeIsVector(context, node)) { + return GetOutputShapeFromTensor(context, node); + } else { + return GetOutputShapeFromParam(context, node); + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Always postpone sizing string tensors, even if we could in principle + // calculate their shapes now. String tensors don't benefit from having their + // shapes precalculated because the actual memory can only be allocated after + // we know all the content. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + if (output->type != kTfLiteString) { + if (NumInputs(node) == 1 || + IsConstantTensor(GetInput(context, node, kShapeTensor))) { + TF_LITE_ENSURE_OK( + context, ResizeOutput(context, node, GetOutputShape(context, node))); + } else { + SetTensorToDynamic(output); + } + } + return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // There are two ways in which the 'output' can be made dynamic: it could be + // a string tensor, or its shape cannot be calculated during Prepare(). In + // either case, we now have all the information to calculate its shape. if (IsDynamicTensor(output)) { - TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node)); + TF_LITE_ENSURE_OK( + context, ResizeOutput(context, node, GetOutputShape(context, node))); + } + + // Note that string tensors are always "dynamic" in the sense that their size + // is not known until we have all the content. This applies even when their + // shape is known ahead of time. As a result, a string tensor is never given + // any memory by ResizeOutput(), and we need to do it manually here. Since + // reshape doesn't change the data, the output tensor needs exactly as many + // bytes as the input tensor. + if (output->type == kTfLiteString) { + auto bytes_required = input->bytes; + TfLiteTensorRealloc(bytes_required, output); + output->bytes = bytes_required; } memcpy(output->data.raw, input->data.raw, input->bytes); diff --git a/tensorflow/lite/kernels/reshape_test.cc b/tensorflow/lite/kernels/reshape_test.cc new file mode 100644 index 00000000000000..00bbbef57eccef --- /dev/null +++ b/tensorflow/lite/kernels/reshape_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +using ::testing::IsEmpty; + +// There are three ways to specify the output shape of a Reshape +// op. +enum ShapeSpecificationType { + // The output shape is hardcoded in the ReshapeOptions object. + kAsReshapeOption, + // The output shape is specified as an input tensor, which is connected to a + // Const node, which is guaranteed not to change once inference starts. The + // shape is also hardcoded as in kAsReshapeOption. + kAsConstantTensor, + // The output shape is specifed as an input tensor that can change based on + // external input. That is, the shape is not know before the inference + // starts. The shape is also hardcoded as in kAsReshapeOption. + kAsTensor, +}; + +class ReshapeOpTest + : public ::testing::Test, + public ::testing::WithParamInterface {}; + +template +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list shape_shape, + std::initializer_list shape_data, + ShapeSpecificationType shape_type) { + switch (shape_type) { + case kAsTensor: + BuildWithTensorShape(input_shape, shape_shape, shape_data); + break; + case kAsConstantTensor: + BuildWithConstantTensorShape(input_shape, shape_shape, shape_data); + break; + case kAsReshapeOption: + // In this case the shape of the new shape doesn't matter. It is + // always hardcoded as a flat vector. + BuildWithHardcodedShape(input_shape, shape_data); + break; + } + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetStringInput(std::initializer_list data) { + PopulateStringTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + void BuildWithHardcodedShape(std::initializer_list input_shape, + std::initializer_list shape_data) { + input_ = AddInput({GetTensorType(), input_shape}); + output_ = AddOutput(GetTensorType()); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(shape_data)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + void BuildWithTensorShape(std::initializer_list input_shape, + std::initializer_list shape_shape, + std::initializer_list shape_data) { + input_ = AddInput({GetTensorType(), input_shape}); + output_ = AddOutput(GetTensorType()); + int shape_input_tensor = AddInput({TensorType_INT32, shape_shape}); + // Note how shape also appears in ReshapeOptions + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(shape_data)) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(shape_input_tensor)}); + if (shape_data.size() != 0) { + PopulateTensor(shape_input_tensor, shape_data); + } + } + + void BuildWithConstantTensorShape(std::initializer_list input_shape, + std::initializer_list shape_shape, + std::initializer_list shape_data) { + input_ = AddInput({GetTensorType(), input_shape}); + output_ = AddOutput(GetTensorType()); + AddConstInput(TensorType_INT32, shape_data, shape_shape); + // Note how the shape also appears in the ReshapeOptions. + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(shape_data)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input_; + int output_; +}; + +TEST_P(ReshapeOpTest, MismatchedDimensions) { + if (GetParam() == kAsTensor) { + ReshapeOpModel m({1, 2, 4, 1}, {2}, {2, 1}, GetParam()); + m.SetInput({3}); + EXPECT_DEATH(m.Invoke(), "num_input_elements != num_output_elements"); + } else { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2}, {2, 1}, GetParam()), + "num_input_elements != num_output_elements"); + } +} + +TEST_P(ReshapeOpTest, TooManyDimensions) { + if (GetParam() == kAsReshapeOption) { + EXPECT_DEATH(ReshapeOpModel({1, 1, 2, 1, 1, 1, 1, 1, 1}, {9}, + {1, 1, 1, 1, 1, 1, 1, 1, 2}, GetParam()), + "Found too many dimensions"); + } else { + ReshapeOpModel m({1, 1, 2, 1, 1, 1, 1, 1, 1}, {9}, + {1, 1, 1, 1, 1, 1, 1, 1, 2}, GetParam()); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4})); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray({1, 1, 1, 1, 1, 1, 1, 1, 2})); + } +} + +TEST_P(ReshapeOpTest, TooManySpecialDimensions) { + if (GetParam() != kAsTensor) { + EXPECT_DEATH( + ReshapeOpModel({1, 2, 4, 1}, {4}, {-1, -1, 2, 4}, GetParam()), + "stretch_dim != -1"); + } else { + ReshapeOpModel m({1, 2, 4, 1}, {4}, {-1, -1, 2, 4}, GetParam()); + EXPECT_DEATH(m.Invoke(), "stretch_dim != -1"); + } +} + +// Create the model with a 2x2 shape. Processing still works because the new +// shape ends up being hardcoded as a flat vector. +TEST_P(ReshapeOpTest, InvalidShape) { + ReshapeOpModel m({1, 2, 2}, {2, 2}, {1, 2, 2, 1}, GetParam()); + m.SetInput({5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 6, 7, 8})); +} + +// This is the normal scenario, where shape is a vector. +TEST_P(ReshapeOpTest, RegularShapes) { + ReshapeOpModel m({1, 2, 4, 1}, {3}, {2, 2, 2}, GetParam()); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +TEST_P(ReshapeOpTest, WithStretchDimension) { + ReshapeOpModel m({1, 2, 4, 1}, {3}, {2, 1, -1}, GetParam()); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); +} + +// Shape is specified as '[]', which is the modern way to represent scalar +// input and output. +TEST_P(ReshapeOpTest, ScalarOutput) { + ReshapeOpModel m({1}, {0}, {}, GetParam()); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +// Some old models specify '[0]' as the new shape, indicating that both input +// and output are scalars. +TEST_P(ReshapeOpTest, LegacyScalarOutput) { + if (GetParam() == kAsConstantTensor) { + EXPECT_DEATH(ReshapeOpModel({1}, {1}, {0}, GetParam()), + "num_input_elements != num_output_elements"); + } else if (GetParam() == kAsTensor) { + ReshapeOpModel m({1}, {1}, {0}, GetParam()); + m.SetInput({3}); + EXPECT_DEATH(m.Invoke(), "num_input_elements != num_output_elements"); + } else { + ReshapeOpModel m({1}, {1}, {0}, GetParam()); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + } +} + +TEST_P(ReshapeOpTest, Strings) { + ReshapeOpModel m({1, 2, 4, 1}, {3}, {2, 2, 2}, GetParam()); + m.SetStringInput({"1", "2", "3", "4", "5", "6", "7", "8"}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({"1", "2", "3", "4", "5", "6", "7", "8"})); +} + +INSTANTIATE_TEST_CASE_P(VariedShapeSpec, ReshapeOpTest, + ::testing::Values(kAsReshapeOption, kAsConstantTensor, + kAsTensor)); +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/lite/kernels/resize_bilinear.cc similarity index 92% rename from tensorflow/contrib/lite/kernels/resize_bilinear.cc rename to tensorflow/lite/kernels/resize_bilinear.cc index fb045d15f35735..d42cb188669587 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/lite/kernels/resize_bilinear.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/lite/kernels/resize_bilinear_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/resize_bilinear_test.cc rename to tensorflow/lite/kernels/resize_bilinear_test.cc index f4289105f7931a..530bb32b946f07 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/lite/kernels/resize_bilinear_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/resize_nearest_neighbor.cc b/tensorflow/lite/kernels/resize_nearest_neighbor.cc similarity index 92% rename from tensorflow/contrib/lite/kernels/resize_nearest_neighbor.cc rename to tensorflow/lite/kernels/resize_nearest_neighbor.cc index 95c920f95c5ec7..a48d8004f8b6ce 100644 --- a/tensorflow/contrib/lite/kernels/resize_nearest_neighbor.cc +++ b/tensorflow/lite/kernels/resize_nearest_neighbor.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/resize_nearest_neighbor_test.cc rename to tensorflow/lite/kernels/resize_nearest_neighbor_test.cc index b2154ff72c4c16..03e2effd84c4ad 100644 --- a/tensorflow/contrib/lite/kernels/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/resize_nearest_neighbor_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/lite/kernels/select.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/select.cc rename to tensorflow/lite/kernels/select.cc index 4780a86ee51ee3..4687ab44171fab 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/lite/kernels/select.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/lite/kernels/select_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/select_test.cc rename to tensorflow/lite/kernels/select_test.cc index 5b2e61cd29a7fd..5111300e479a92 100644 --- a/tensorflow/contrib/lite/kernels/select_test.cc +++ b/tensorflow/lite/kernels/select_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/lite/kernels/shape.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/shape.cc rename to tensorflow/lite/kernels/shape.cc index 66d4c9e5c1a430..934f0846b9e839 100644 --- a/tensorflow/contrib/lite/kernels/shape.cc +++ b/tensorflow/lite/kernels/shape.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/lite/kernels/shape_test.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/shape_test.cc rename to tensorflow/lite/kernels/shape_test.cc index 27b48f4e992a8f..0c13ff45b0a3c0 100644 --- a/tensorflow/contrib/lite/kernels/shape_test.cc +++ b/tensorflow/lite/kernels/shape_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/lite/kernels/skip_gram.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/skip_gram.cc rename to tensorflow/lite/kernels/skip_gram.cc index de80a4016ecd6f..f20719ecaf6eda 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram.cc +++ b/tensorflow/lite/kernels/skip_gram.cc @@ -33,11 +33,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/lite/kernels/skip_gram_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/skip_gram_test.cc rename to tensorflow/lite/kernels/skip_gram_test.cc index 185b64cb44969b..d4430b8a343040 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram_test.cc +++ b/tensorflow/lite/kernels/skip_gram_test.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/lite/kernels/slice.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/slice.cc rename to tensorflow/lite/kernels/slice.cc index ccfee41b9ca58f..116c81e4d57a9a 100644 --- a/tensorflow/contrib/lite/kernels/slice.cc +++ b/tensorflow/lite/kernels/slice.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { @@ -107,7 +107,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Ensure validity of input tensor and its dimension. - TF_LITE_ENSURE_EQ(context, input->type, output->type); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); TF_LITE_ENSURE(context, begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64); TF_LITE_ENSURE(context, diff --git a/tensorflow/contrib/lite/kernels/slice_test.cc b/tensorflow/lite/kernels/slice_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/slice_test.cc rename to tensorflow/lite/kernels/slice_test.cc index 4828f88f36bc1e..563329ddb164d3 100644 --- a/tensorflow/contrib/lite/kernels/slice_test.cc +++ b/tensorflow/lite/kernels/slice_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/lite/kernels/softmax_test.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/softmax_test.cc rename to tensorflow/lite/kernels/softmax_test.cc index bd66980226cee0..eb9d7c1d9de694 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/lite/kernels/softmax_test.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/lite/kernels/space_to_batch_nd.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/space_to_batch_nd.cc rename to tensorflow/lite/kernels/space_to_batch_nd.cc index 3a10d2e60cf617..1c61b2ef30379e 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/lite/kernels/space_to_batch_nd.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/lite/kernels/space_to_batch_nd_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc rename to tensorflow/lite/kernels/space_to_batch_nd_test.cc index 5756573629a519..4d55ba56b71c5e 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/lite/kernels/space_to_batch_nd_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/lite/kernels/space_to_depth.cc similarity index 91% rename from tensorflow/contrib/lite/kernels/space_to_depth.cc rename to tensorflow/lite/kernels/space_to_depth.cc index 64c56c017b0b4a..79e28bf47d98b6 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth.cc +++ b/tensorflow/lite/kernels/space_to_depth.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/lite/kernels/space_to_depth_test.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/space_to_depth_test.cc rename to tensorflow/lite/kernels/space_to_depth_test.cc index 997f354861a235..5744669b6d62af 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc +++ b/tensorflow/lite/kernels/space_to_depth_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/lite/kernels/sparse_output_fully_connected.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc rename to tensorflow/lite/kernels/sparse_output_fully_connected.cc index 66daf5e84a0567..73d850f0e2d094 100644 --- a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc +++ b/tensorflow/lite/kernels/sparse_output_fully_connected.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ // SparseOutputFullyConnected is a fully connected layer that uses a single // row in the weights and bias via a lookup. -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/lite/kernels/sparse_output_fully_connected_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc rename to tensorflow/lite/kernels/sparse_output_fully_connected_test.cc index 365986a5c177ee..c25a32bde001e6 100644 --- a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc +++ b/tensorflow/lite/kernels/sparse_output_fully_connected_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "flatbuffers/flexbuffers.h" // TF:flatbuffers -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/lite/kernels/sparse_to_dense.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/sparse_to_dense.cc rename to tensorflow/lite/kernels/sparse_to_dense.cc index 349fa0bd281ce3..de4d863facb50b 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/lite/kernels/sparse_to_dense.cc @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/kernels/padding.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/padding.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/lite/kernels/sparse_to_dense_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc rename to tensorflow/lite/kernels/sparse_to_dense_test.cc index a51ec17afcefd7..ee135c220ede17 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc +++ b/tensorflow/lite/kernels/sparse_to_dense_test.cc @@ -15,10 +15,10 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/lite/kernels/split.cc similarity index 92% rename from tensorflow/contrib/lite/kernels/split.cc rename to tensorflow/lite/kernels/split.cc index dab887bf9ccac0..7902ed2a46d297 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/lite/kernels/split.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/split_test.cc b/tensorflow/lite/kernels/split_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/split_test.cc rename to tensorflow/lite/kernels/split_test.cc index 61a0759c647579..f3d9ea3bf4158d 100644 --- a/tensorflow/contrib/lite/kernels/split_test.cc +++ b/tensorflow/lite/kernels/split_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/lite/kernels/squeeze.cc similarity index 92% rename from tensorflow/contrib/lite/kernels/squeeze.cc rename to tensorflow/lite/kernels/squeeze.cc index 080c51cd18204a..8be0c6b9de0810 100644 --- a/tensorflow/contrib/lite/kernels/squeeze.cc +++ b/tensorflow/lite/kernels/squeeze.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/lite/kernels/squeeze_test.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/squeeze_test.cc rename to tensorflow/lite/kernels/squeeze_test.cc index a8aab88357cacb..4a02a8ee7e17ba 100644 --- a/tensorflow/contrib/lite/kernels/squeeze_test.cc +++ b/tensorflow/lite/kernels/squeeze_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/strided_slice.cc rename to tensorflow/lite/kernels/strided_slice.cc index 06b36dd1967a02..c797a98e9f1bda 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/lite/kernels/strided_slice.cc @@ -15,12 +15,12 @@ limitations under the License. #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/strided_slice_test.cc rename to tensorflow/lite/kernels/strided_slice_test.cc index c5d4f9affb46c8..122e01b99ecbed 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/lite/kernels/sub.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/sub.cc rename to tensorflow/lite/kernels/sub.cc index 1be0c83f17a34c..06a3b3499a005f 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/lite/kernels/sub.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/sub_test.cc b/tensorflow/lite/kernels/sub_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/sub_test.cc rename to tensorflow/lite/kernels/sub_test.cc index 5978c574d35492..f0b9447ff61ced 100644 --- a/tensorflow/contrib/lite/kernels/sub_test.cc +++ b/tensorflow/lite/kernels/sub_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/svdf.cc rename to tensorflow/lite/kernels/svdf.cc index 7e6d81239ce36e..f07937140e9ac4 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/lite/kernels/svdf.cc @@ -23,12 +23,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/lite/kernels/svdf_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/svdf_test.cc rename to tensorflow/lite/kernels/svdf_test.cc index 6d60dc63f40114..8accaa465ca8a5 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/lite/kernels/svdf_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/test_util.cc rename to tensorflow/lite/kernels/test_util.cc index 0c0df133e2f464..6b2a1f89c37dd3 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/version.h" #include "tensorflow/core/platform/logging.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h similarity index 89% rename from tensorflow/contrib/lite/kernels/test_util.h rename to tensorflow/lite/kernels/test_util.h index 3bef0de7f9bbf8..43a5137a941d50 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#ifndef TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_ +#define TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_ #include #include @@ -21,12 +21,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/testing/util.h" #include "tensorflow/core/platform/logging.h" namespace tflite { @@ -207,7 +207,14 @@ class SingleOpModel { template void PopulateTensor(int index, const std::initializer_list& data) { T* v = interpreter_->typed_tensor(index); - CHECK(v) << "No tensor with index '" << index << "'."; + if (!v) { + auto* t = interpreter_->tensor(index); + CHECK(t) << "No tensor with index " << index << "."; + CHECK(t->data.raw) << "Empty data for tensor with index " << index << "."; + CHECK(v) << "Type mismatch for tensor with index " << index + << ". Requested " << typeToTfLiteType() << ", got " + << t->type; + } for (T f : data) { *v = f; ++v; @@ -220,7 +227,14 @@ class SingleOpModel { template void PopulateTensor(int index, const std::vector& data) { T* v = interpreter_->typed_tensor(index); - CHECK(v) << "No tensor with index '" << index << "'."; + if (!v) { + auto* t = interpreter_->tensor(index); + CHECK(t) << "No tensor with index " << index << "."; + CHECK(t->data.raw) << "Empty data for tensor with index " << index << "."; + CHECK(v) << "Type mismatch for tensor with index " << index + << ". Requested " << typeToTfLiteType() << ", got " + << t->type; + } for (T f : data) { *v = f; ++v; @@ -264,7 +278,7 @@ class SingleOpModel { private: // TODO(gavinbelson): sync this method with - // //tensorflow/contrib/lite/kernels/internal/quantization_util.h?l=31 + // //tensorflow/lite/kernels/internal/quantization_util.h?l=31 template std::pair QuantizationParams(float f_min, float f_max) { // These are required by many quantized operations. @@ -384,9 +398,19 @@ class SingleOpTest : public ::testing::TestWithParam { } }; +// Returns the corresponding TensorType given the type T. +template +TensorType GetTensorType() { + if (std::is_same::value) return TensorType_FLOAT32; + if (std::is_same::value) return TensorType_INT32; + if (std::is_same::value) return TensorType_UINT8; + if (std::is_same::value) return TensorType_STRING; + return TensorType_MIN; // default value +} + // Strings have a special implementation that is in test_util.cc template <> std::vector SingleOpModel::ExtractVector(int index); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#endif // TENSORFLOW_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/lite/kernels/test_util_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/test_util_test.cc rename to tensorflow/lite/kernels/test_util_test.cc index 236580347254d3..7abb7011f9d23e 100644 --- a/tensorflow/contrib/lite/kernels/test_util_test.cc +++ b/tensorflow/lite/kernels/test_util_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/lite/kernels/test_util.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/lite/kernels/tile.cc similarity index 88% rename from tensorflow/contrib/lite/kernels/tile.cc rename to tensorflow/lite/kernels/tile.cc index 49421eb8708162..6d13f9e92f9bd4 100644 --- a/tensorflow/contrib/lite/kernels/tile.cc +++ b/tensorflow/lite/kernels/tile.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { namespace builtin { @@ -63,7 +63,9 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { MultiplyShapeDims(*input->dims, multipliers, num_dimensions)); default: - context->ReportError(context, "Tile not supported multiply tensor type."); + context->ReportError( + context, "Multipliers of type '%s' are not supported by tile.", + TfLiteTypeGetName(multipliers->type)); return kTfLiteError; } } @@ -143,10 +145,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); // Only int32 and int64 multipliers type is supported. - TF_LITE_ENSURE_MSG(context, - (multipliers->type == kTfLiteInt32) || - (multipliers->type == kTfLiteInt64), - "Tile only supports int32 and int64 mutlipliers."); + if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) { + context->ReportError(context, + "Multipliers of type '%s' are not supported by tile.", + TfLiteTypeGetName(multipliers->type)); + return kTfLiteError; + } if (IsConstantTensor(multipliers)) { TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); @@ -179,7 +183,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { Tile(*(input->dims), input, multipliers, output); break; default: - context->ReportError(context, "Type is currently not supported by Tile."); + context->ReportError(context, "Type '%s' is not supported by tile.", + TfLiteTypeGetName(output->type)); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/lite/kernels/tile_test.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/tile_test.cc rename to tensorflow/lite/kernels/tile_test.cc index e73ca7b7504f6f..d12a7c19a367bf 100644 --- a/tensorflow/contrib/lite/kernels/tile_test.cc +++ b/tensorflow/lite/kernels/tile_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/lite/kernels/topk_v2.cc similarity index 96% rename from tensorflow/contrib/lite/kernels/topk_v2.cc rename to tensorflow/lite/kernels/topk_v2.cc index 6c38b6739e8751..444b01e7b2e055 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/lite/kernels/topk_v2.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { namespace builtin { diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/lite/kernels/topk_v2_test.cc similarity index 94% rename from tensorflow/contrib/lite/kernels/topk_v2_test.cc rename to tensorflow/lite/kernels/topk_v2_test.cc index 16106fdafeeaaa..108b8123666aad 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/lite/kernels/topk_v2_test.cc @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/lite/kernels/transpose.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/transpose.cc rename to tensorflow/lite/kernels/transpose.cc index e42a30420b278a..7a6d320674ad1c 100644 --- a/tensorflow/contrib/lite/kernels/transpose.cc +++ b/tensorflow/lite/kernels/transpose.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc similarity index 95% rename from tensorflow/contrib/lite/kernels/transpose_conv.cc rename to tensorflow/lite/kernels/transpose_conv.cc index f8c858d6453989..59eee51068c0ef 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/lite/kernels/transpose_conv.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "tensorflow/contrib/lite/kernels/padding.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/eigen_support.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/padding.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/transpose_conv_test.cc rename to tensorflow/lite/kernels/transpose_conv_test.cc index 07fca344d6a012..0520d84a30b502 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/kernels/transpose_conv_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/lite/kernels/transpose_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/transpose_test.cc rename to tensorflow/lite/kernels/transpose_test.cc index 79ef0a7c562d07..3ebaf3ca27ffd2 100644 --- a/tensorflow/contrib/lite/kernels/transpose_test.cc +++ b/tensorflow/lite/kernels/transpose_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc rename to tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index bd6d4d1f884581..497777b9aff6c6 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -20,14 +20,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/lstm_eval.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/lstm_eval.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc similarity index 99% rename from tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc rename to tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc index 1de14dd60db0c0..ae7dd6b2bee1da 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc rename to tensorflow/lite/kernels/unidirectional_sequence_rnn.cc index 550a0bc02a195e..4c0fe00272a04e 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/activation_functor.h" -#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/activation_functor.h" +#include "tensorflow/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc similarity index 98% rename from tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc rename to tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc index 6b48e3fff7a9db..a2f82ac67b1b22 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/unpack.cc rename to tensorflow/lite/kernels/unpack.cc index a7d3a9bc7672be..1caffe14f90b8c 100644 --- a/tensorflow/contrib/lite/kernels/unpack.cc +++ b/tensorflow/lite/kernels/unpack.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc similarity index 97% rename from tensorflow/contrib/lite/kernels/unpack_test.cc rename to tensorflow/lite/kernels/unpack_test.cc index 4efc92a0fdd680..9b60cce549804a 100644 --- a/tensorflow/contrib/lite/kernels/unpack_test.cc +++ b/tensorflow/lite/kernels/unpack_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/kernels/zeros_like.cc b/tensorflow/lite/kernels/zeros_like.cc similarity index 93% rename from tensorflow/contrib/lite/kernels/zeros_like.cc rename to tensorflow/lite/kernels/zeros_like.cc index cce5240a9bdc9c..a187306fa251c4 100644 --- a/tensorflow/contrib/lite/kernels/zeros_like.cc +++ b/tensorflow/lite/kernels/zeros_like.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/zeros_like_test.cc b/tensorflow/lite/kernels/zeros_like_test.cc similarity index 92% rename from tensorflow/contrib/lite/kernels/zeros_like_test.cc rename to tensorflow/lite/kernels/zeros_like_test.cc index d3382d1d5b865e..0a1d9afe33f897 100644 --- a/tensorflow/contrib/lite/kernels/zeros_like_test.cc +++ b/tensorflow/lite/kernels/zeros_like_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/lib_package/BUILD b/tensorflow/lite/lib_package/BUILD similarity index 100% rename from tensorflow/contrib/lite/lib_package/BUILD rename to tensorflow/lite/lib_package/BUILD diff --git a/tensorflow/contrib/lite/lib_package/concat_licenses.sh b/tensorflow/lite/lib_package/concat_licenses.sh similarity index 100% rename from tensorflow/contrib/lite/lib_package/concat_licenses.sh rename to tensorflow/lite/lib_package/concat_licenses.sh diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/lite/lib_package/create_ios_frameworks.sh similarity index 87% rename from tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh rename to tensorflow/lite/lib_package/create_ios_frameworks.sh index 6195426d6d441e..fa466ed5bc7ad3 100755 --- a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh +++ b/tensorflow/lite/lib_package/create_ios_frameworks.sh @@ -32,13 +32,13 @@ mkdir -p $FW_DIR_TFLITE_HDRS echo "Headers, populating: TensorFlow Lite" cd $TFLITE_DIR/../../.. -find tensorflow/contrib/lite -name '*.h' \ - -not -path 'tensorflow/contrib/lite/tools/*' \ - -not -path 'tensorflow/contrib/lite/examples/*' \ - -not -path 'tensorflow/contrib/lite/gen/*' \ - -not -path 'tensorflow/contrib/lite/toco/*' \ - -not -path 'tensorflow/contrib/lite/nnapi/*' \ - -not -path 'tensorflow/contrib/lite/java/*' \ +find tensorflow/lite -name '*.h' \ + -not -path 'tensorflow/lite/tools/*' \ + -not -path 'tensorflow/lite/examples/*' \ + -not -path 'tensorflow/lite/gen/*' \ + -not -path 'tensorflow/lite/toco/*' \ + -not -path 'tensorflow/lite/nnapi/*' \ + -not -path 'tensorflow/lite/java/*' \ | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - cd $FW_DIR_TFLITE_HDRS tar xf tmp.tar diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/lite/memory_planner.h similarity index 88% rename from tensorflow/contrib/lite/memory_planner.h rename to tensorflow/lite/memory_planner.h index 2d4707f849f5d1..fa2a44a1c89d70 100644 --- a/tensorflow/contrib/lite/memory_planner.h +++ b/tensorflow/lite/memory_planner.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ -#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#ifndef TENSORFLOW_LITE_MEMORY_PLANNER_H_ +#define TENSORFLOW_LITE_MEMORY_PLANNER_H_ -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -42,4 +42,4 @@ class MemoryPlanner { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#endif // TENSORFLOW_LITE_MEMORY_PLANNER_H_ diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/lite/mmap_allocation.cc similarity index 94% rename from tensorflow/contrib/lite/mmap_allocation.cc rename to tensorflow/lite/mmap_allocation.cc index 92934d1fd15777..11e59956996f26 100644 --- a/tensorflow/contrib/lite/mmap_allocation.cc +++ b/tensorflow/lite/mmap_allocation.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/mmap_allocation_disabled.cc b/tensorflow/lite/mmap_allocation_disabled.cc similarity index 96% rename from tensorflow/contrib/lite/mmap_allocation_disabled.cc rename to tensorflow/lite/mmap_allocation_disabled.cc index f3d4cf1a257d43..efb0991b5941f1 100644 --- a/tensorflow/contrib/lite/mmap_allocation_disabled.cc +++ b/tensorflow/lite/mmap_allocation_disabled.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/lite/allocation.h" #include diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/lite/model.cc similarity index 97% rename from tensorflow/contrib/lite/model.cc rename to tensorflow/lite/model.cc index a8a010be1a1800..5ac0532afeffc0 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/lite/model.cc @@ -19,15 +19,15 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/model.h" #ifndef TFLITE_MCU -#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/lite/nnapi_delegate.h" #endif -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/version.h" namespace tflite { @@ -404,8 +404,7 @@ TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) { } if (auto flex_delegate = AcquireFlexDelegate()) { - return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate), - /*allow_dynamic_tensors=*/true); + return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/lite/model.h similarity index 95% rename from tensorflow/contrib/lite/model.h rename to tensorflow/lite/model.h index 9505824dcc933b..01e7c682056b2b 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/lite/model.h @@ -31,15 +31,15 @@ limitations under the License. // OpResolver must be defined to provide your kernel implementations to the // interpreter. This is environment specific and may consist of just the builtin // ops, or some custom operators you defined to extend tflite. -#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#ifndef TENSORFLOW_LITE_MODEL_H_ +#define TENSORFLOW_LITE_MODEL_H_ #include -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/core/api/op_resolver.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/mutable_op_resolver.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -186,4 +186,4 @@ class InterpreterBuilder { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#endif // TENSORFLOW_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/model_flex_test.cc b/tensorflow/lite/model_flex_test.cc similarity index 86% rename from tensorflow/contrib/lite/model_flex_test.cc rename to tensorflow/lite/model_flex_test.cc index 52e76bee4941c4..88b3c886b21d16 100644 --- a/tensorflow/contrib/lite/model_flex_test.cc +++ b/tensorflow/lite/model_flex_test.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/model.h" #include -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { @@ -24,7 +24,7 @@ namespace tflite { // appropriate delegate is linked into the client. TEST(FlexModel, WithFlexDelegate) { auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + "tensorflow/lite/testdata/multi_add_flex.bin"); ASSERT_TRUE(model); std::unique_ptr interpreter; diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/lite/model_test.cc similarity index 91% rename from tensorflow/contrib/lite/model_test.cc rename to tensorflow/lite/model_test.cc index b969bea5dcff2f..e677ea94a71b97 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/lite/model_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/model.h" #include -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, // we must declare this in global namespace, so argument-dependent operator @@ -75,7 +75,7 @@ TEST(BasicFlatBufferModel, TestNonExistantFiles) { // Make sure a model with nothing in it loads properly. TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/empty_model.bin"); + "tensorflow/lite/testdata/empty_model.bin"); ASSERT_TRUE(model); // Now try to build it into a model. std::unique_ptr interpreter; @@ -89,14 +89,14 @@ TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { // TODO(aselle): Replace this test when multiple subgraphs are supported. TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { auto m1 = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + "tensorflow/lite/testdata/0_subgraphs.bin"); ASSERT_TRUE(m1); std::unique_ptr interpreter1; ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1), kTfLiteOk); auto m2 = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/2_subgraphs.bin"); + "tensorflow/lite/testdata/2_subgraphs.bin"); ASSERT_TRUE(m2); std::unique_ptr interpreter2; ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2), @@ -106,7 +106,7 @@ TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { // Test what happens if we cannot bind any of the ops. TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/test_model.bin"); + "tensorflow/lite/testdata/test_model.bin"); ASSERT_TRUE(model); // Check that we get an error code and interpreter pointer is reset. std::unique_ptr interpreter(new Interpreter); @@ -118,7 +118,7 @@ TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { // Make sure model is read to interpreter propelrly TEST(BasicFlatBufferModel, TestModelInInterpreter) { auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/test_model.bin"); + "tensorflow/lite/testdata/test_model.bin"); ASSERT_TRUE(model); // Check that we get an error code and interpreter pointer is reset. std::unique_ptr interpreter(new Interpreter); @@ -198,7 +198,7 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) { // not linked into the target. TEST(FlexModel, FailureWithoutFlexDelegate) { auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/multi_add_flex.bin"); + "tensorflow/lite/testdata/multi_add_flex.bin"); ASSERT_TRUE(model); // Note that creation will succeed when using the BuiltinOpResolver, but @@ -219,7 +219,7 @@ TEST(FlexModel, FailureWithoutFlexDelegate) { // buffer. But the buffer is provided to be only 1 element. TEST(BasicFlatBufferModel, TestBrokenMmap) { ASSERT_FALSE(FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/test_model_broken.bin")); + "tensorflow/lite/testdata/test_model_broken.bin")); } TEST(BasicFlatBufferModel, TestNullModel) { @@ -247,20 +247,20 @@ class FakeVerifier : public tflite::TfLiteVerifier { TEST(BasicFlatBufferModel, TestWithTrueVerifier) { FakeVerifier verifier(true); ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile( - "tensorflow/contrib/lite/testdata/test_model.bin", + "tensorflow/lite/testdata/test_model.bin", &verifier)); } TEST(BasicFlatBufferModel, TestWithFalseVerifier) { FakeVerifier verifier(false); ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile( - "tensorflow/contrib/lite/testdata/test_model.bin", + "tensorflow/lite/testdata/test_model.bin", &verifier)); } TEST(BasicFlatBufferModel, TestWithNullVerifier) { ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile( - "tensorflow/contrib/lite/testdata/test_model.bin", nullptr)); + "tensorflow/lite/testdata/test_model.bin", nullptr)); } // This makes sure the ErrorReporter is marshalled from FlatBufferModel to @@ -268,7 +268,7 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) { TEST(BasicFlatBufferModel, TestCustomErrorReporter) { TestErrorReporter reporter; auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/empty_model.bin", + "tensorflow/lite/testdata/empty_model.bin", &reporter); ASSERT_TRUE(model); @@ -283,7 +283,7 @@ TEST(BasicFlatBufferModel, TestCustomErrorReporter) { // the Interpreter. TEST(BasicFlatBufferModel, TestNullErrorReporter) { auto model = FlatBufferModel::BuildFromFile( - "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr); + "tensorflow/lite/testdata/empty_model.bin", nullptr); ASSERT_TRUE(model); std::unique_ptr interpreter; @@ -296,7 +296,7 @@ TEST(BasicFlatBufferModel, TestNullErrorReporter) { TEST(BasicFlatBufferModel, TestBuildFromModel) { TestErrorReporter reporter; FileCopyAllocation model_allocation( - "tensorflow/contrib/lite/testdata/test_model.bin", &reporter); + "tensorflow/lite/testdata/test_model.bin", &reporter); ASSERT_TRUE(model_allocation.valid()); ::flatbuffers::Verifier verifier( reinterpret_cast(model_allocation.base()), diff --git a/tensorflow/contrib/lite/models/BUILD b/tensorflow/lite/models/BUILD similarity index 74% rename from tensorflow/contrib/lite/models/BUILD rename to tensorflow/lite/models/BUILD index efa47b06fa7f06..8730160e4005df 100644 --- a/tensorflow/contrib/lite/models/BUILD +++ b/tensorflow/lite/models/BUILD @@ -7,7 +7,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") exports_files(glob([ "testdata/*", diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/lite/models/smartreply/BUILD similarity index 60% rename from tensorflow/contrib/lite/models/smartreply/BUILD rename to tensorflow/lite/models/smartreply/BUILD index 9d88c396ba6994..078b8e6bc6a288 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/lite/models/smartreply/BUILD @@ -1,6 +1,6 @@ package(default_visibility = ["//visibility:public"]) -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") +load("//tensorflow/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") licenses(["notice"]) # Apache 2.0 @@ -19,9 +19,9 @@ cc_library( ], copts = tflite_copts(), deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", "@farmhash_archive//:farmhash", @@ -35,9 +35,9 @@ cc_library( copts = tflite_copts(), deps = [ ":custom_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", ], @@ -50,9 +50,9 @@ cc_test( tags = ["no_oss"], deps = [ ":custom_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", "@farmhash_archive//:farmhash", ], @@ -65,10 +65,10 @@ cc_test( tags = ["no_oss"], deps = [ ":custom_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) @@ -80,10 +80,10 @@ cc_test( tags = ["no_oss"], deps = [ ":custom_ops", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:test_util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml rename to tensorflow/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/lite/models/smartreply/demo/app/src/main/BUILD similarity index 88% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD rename to tensorflow/lite/models/smartreply/demo/app/src/main/BUILD index 2e5033dab1356e..b14af4cb20b893 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD +++ b/tensorflow/lite/models/smartreply/demo/app/src/main/BUILD @@ -5,7 +5,7 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 load( - "//tensorflow/contrib/lite:build_def.bzl", + "//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_jni_binary", ) @@ -61,8 +61,8 @@ cc_library( "-ldl", ], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + "//tensorflow/lite:framework", + "//tensorflow/lite/models/smartreply:predictor_lib", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/lite/models/smartreply/demo/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD rename to tensorflow/lite/models/smartreply/demo/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt rename to tensorflow/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java rename to tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java rename to tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java rename to tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml similarity index 100% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml rename to tensorflow/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc b/tensorflow/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc similarity index 97% rename from tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc rename to tensorflow/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc index f158cc511a9bee..9b5df36c37a1d2 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc +++ b/tensorflow/lite/models/smartreply/demo/app/src/main/smartreply_jni.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/models/smartreply/predictor.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/models/smartreply/predictor.h" const char kIllegalStateException[] = "java/lang/IllegalStateException"; diff --git a/tensorflow/contrib/lite/models/smartreply/g3doc/README.md b/tensorflow/lite/models/smartreply/g3doc/README.md similarity index 98% rename from tensorflow/contrib/lite/models/smartreply/g3doc/README.md rename to tensorflow/lite/models/smartreply/g3doc/README.md index a6d75648b3f3da..1b8ff15196cd4d 100644 --- a/tensorflow/contrib/lite/models/smartreply/g3doc/README.md +++ b/tensorflow/lite/models/smartreply/g3doc/README.md @@ -38,7 +38,7 @@ The On-Device Smart Reply model is aimed towards improving the messaging experience for day-to-day conversational chat messages. We recommend using this model for similar use cases. Some sample messages on which the model does well are provided in this [tsv -file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv) +file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/testdata/smartreply_samples.tsv) for reference. The file format is: ``` @@ -143,4 +143,4 @@ Following are the ops supported for using On-Device Smart Reply model: ## Further Information * Open source code - [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/smartreply/). + [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/smartreply/). diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/lite/models/smartreply/ops/extract_feature.cc similarity index 95% rename from tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc rename to tensorflow/lite/models/smartreply/ops/extract_feature.cc index 29c8ad2286d705..f9d29229457c40 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc +++ b/tensorflow/lite/models/smartreply/ops/extract_feature.cc @@ -24,9 +24,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/lite/models/smartreply/ops/extract_feature_test.cc similarity index 93% rename from tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc rename to tensorflow/lite/models/smartreply/ops/extract_feature_test.cc index 9b8676bab6e811..efe59eeb4667cc 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc +++ b/tensorflow/lite/models/smartreply/ops/extract_feature_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" #include namespace tflite { diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/lite/models/smartreply/ops/normalize.cc similarity index 95% rename from tensorflow/contrib/lite/models/smartreply/ops/normalize.cc rename to tensorflow/lite/models/smartreply/ops/normalize.cc index c55ac9f52f7293..8480260f279c00 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc +++ b/tensorflow/lite/models/smartreply/ops/normalize.cc @@ -28,9 +28,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/strip.h" #include "re2/re2.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/lite/models/smartreply/ops/normalize_test.cc similarity index 90% rename from tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc rename to tensorflow/lite/models/smartreply/ops/normalize_test.cc index 4d35dba9a64a84..8c5131565d5892 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc +++ b/tensorflow/lite/models/smartreply/ops/normalize_test.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/lite/models/smartreply/ops/predict.cc similarity index 99% rename from tensorflow/contrib/lite/models/smartreply/ops/predict.cc rename to tensorflow/lite/models/smartreply/ops/predict.cc index 7b23adb990cf10..bb2ed4a3153ceb 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc +++ b/tensorflow/lite/models/smartreply/ops/predict.cc @@ -31,7 +31,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/lite/context.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/lite/models/smartreply/ops/predict_test.cc similarity index 95% rename from tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc rename to tensorflow/lite/models/smartreply/ops/predict_test.cc index e97c58cbd18502..ca64dcaad47108 100644 --- a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc +++ b/tensorflow/lite/models/smartreply/ops/predict_test.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/lite/models/smartreply/predictor.cc similarity index 92% rename from tensorflow/contrib/lite/models/smartreply/predictor.cc rename to tensorflow/lite/models/smartreply/predictor.cc index 5d6c47dce8d901..7db2502977707d 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.cc +++ b/tensorflow/lite/models/smartreply/predictor.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/models/smartreply/predictor.h" +#include "tensorflow/lite/models/smartreply/predictor.h" #include "absl/strings/str_split.h" #include "re2/re2.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/op_resolver.h" +#include "tensorflow/lite/string_util.h" void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/lite/models/smartreply/predictor.h similarity index 91% rename from tensorflow/contrib/lite/models/smartreply/predictor.h rename to tensorflow/lite/models/smartreply/predictor.h index 3151192d9277b6..6b8f9298a36f6f 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.h +++ b/tensorflow/lite/models/smartreply/predictor.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ -#define TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#ifndef TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#define TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ #include #include -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/model.h" namespace tflite { namespace custom { @@ -77,4 +77,4 @@ struct SmartReplyConfig { } // namespace custom } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#endif // TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/lite/models/smartreply/predictor_test.cc similarity index 94% rename from tensorflow/contrib/lite/models/smartreply/predictor_test.cc rename to tensorflow/lite/models/smartreply/predictor_test.cc index c7e08814fdf502..7eba26993e5917 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/lite/models/smartreply/predictor_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/models/smartreply/predictor.h" +#include "tensorflow/lite/models/smartreply/predictor.h" #include #include @@ -22,8 +22,8 @@ limitations under the License. #include #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" -//#include "tensorflow/contrib/lite/models/test_utils.h" -#include "tensorflow/contrib/lite/string_util.h" +//#include "tensorflow/lite/models/test_utils.h" +#include "tensorflow/lite/string_util.h" #include "tensorflow/core/platform/test.h" namespace tflite { @@ -36,7 +36,7 @@ const char kSamples[] = "smartreply_samples.tsv"; string TestDataPath() { return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", - "contrib/lite/models/testdata/")); + "lite/models/testdata/")); } MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/lite/models/speech_test.cc similarity index 96% rename from tensorflow/contrib/lite/models/speech_test.cc rename to tensorflow/lite/models/speech_test.cc index 8ecf0b6154a622..17b7e8f28e8fb0 100644 --- a/tensorflow/contrib/lite/models/speech_test.cc +++ b/tensorflow/lite/models/speech_test.cc @@ -21,14 +21,14 @@ limitations under the License. #include "testing/base/public/googletest.h" #include -#include "tensorflow/contrib/lite/testing/parse_testdata.h" -#include "tensorflow/contrib/lite/testing/split.h" -#include "tensorflow/contrib/lite/testing/tflite_driver.h" +#include "tensorflow/lite/testing/parse_testdata.h" +#include "tensorflow/lite/testing/split.h" +#include "tensorflow/lite/testing/tflite_driver.h" namespace tflite { namespace { -const char kDataPath[] = "third_party/tensorflow/contrib/lite/models/testdata/"; +const char kDataPath[] = "third_party/tensorflow/lite/models/testdata/"; bool Init(const string& in_file_name, testing::TfLiteDriver* driver, std::ifstream* in_file) { diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/lite/models/testdata/g3doc/README.md similarity index 93% rename from tensorflow/contrib/lite/models/testdata/g3doc/README.md rename to tensorflow/lite/models/testdata/g3doc/README.md index 1c47e00aae2a0e..2a4f1c143a2172 100644 --- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md +++ b/tensorflow/lite/models/testdata/g3doc/README.md @@ -118,26 +118,26 @@ model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/spee ### Test benches [Speech hotword model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc) +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_hotword_model_test.cc) [Speaker-id model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc) +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_speakerid_model_test.cc) [TTS model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc) +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_tts_model_test.cc) [ASR AM model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc) +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_asr_am_model_test.cc) [ASR LM model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc) +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_asr_lm_model_test.cc) [Endpointer model -test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc) +test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/speech_endpointer_model_test.cc) ## Android Support The models have been tested on Android phones, using the following tests: -[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=25) +[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/android/BUILD?rcl=172930882&l=25) -[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=36) +[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/android/BUILD?rcl=172930882&l=36) diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg b/tensorflow/lite/models/testdata/g3doc/asr_am.svg similarity index 100% rename from tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg rename to tensorflow/lite/models/testdata/g3doc/asr_am.svg diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/asr_lm.svg b/tensorflow/lite/models/testdata/g3doc/asr_lm.svg similarity index 100% rename from tensorflow/contrib/lite/models/testdata/g3doc/asr_lm.svg rename to tensorflow/lite/models/testdata/g3doc/asr_lm.svg diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/endpointer.svg b/tensorflow/lite/models/testdata/g3doc/endpointer.svg similarity index 100% rename from tensorflow/contrib/lite/models/testdata/g3doc/endpointer.svg rename to tensorflow/lite/models/testdata/g3doc/endpointer.svg diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg b/tensorflow/lite/models/testdata/g3doc/hotword.svg similarity index 100% rename from tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg rename to tensorflow/lite/models/testdata/g3doc/hotword.svg diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg b/tensorflow/lite/models/testdata/g3doc/speakerid.svg similarity index 100% rename from tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg rename to tensorflow/lite/models/testdata/g3doc/speakerid.svg diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg b/tensorflow/lite/models/testdata/g3doc/tts.svg similarity index 100% rename from tensorflow/contrib/lite/models/testdata/g3doc/tts.svg rename to tensorflow/lite/models/testdata/g3doc/tts.svg diff --git a/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv b/tensorflow/lite/models/testdata/smartreply_samples.tsv similarity index 100% rename from tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv rename to tensorflow/lite/models/testdata/smartreply_samples.tsv diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/lite/models/testdata/speech_asr_lm_model.test_spec similarity index 100% rename from tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec rename to tensorflow/lite/models/testdata/speech_asr_lm_model.test_spec diff --git a/tensorflow/contrib/lite/mutable_op_resolver.cc b/tensorflow/lite/mutable_op_resolver.cc similarity index 97% rename from tensorflow/contrib/lite/mutable_op_resolver.cc rename to tensorflow/lite/mutable_op_resolver.cc index a36404399bb3e0..36c512dcaacef9 100644 --- a/tensorflow/contrib/lite/mutable_op_resolver.cc +++ b/tensorflow/lite/mutable_op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/mutable_op_resolver.h" +#include "tensorflow/lite/mutable_op_resolver.h" namespace tflite { diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/lite/mutable_op_resolver.h similarity index 91% rename from tensorflow/contrib/lite/mutable_op_resolver.h rename to tensorflow/lite/mutable_op_resolver.h index efd6cfac2ac899..b5700595499714 100644 --- a/tensorflow/contrib/lite/mutable_op_resolver.h +++ b/tensorflow/lite/mutable_op_resolver.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ -#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ +#ifndef TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ #include -#include "tensorflow/contrib/lite/core/api/op_resolver.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/util.h" namespace tflite { @@ -78,4 +78,4 @@ class MutableOpResolver : public OpResolver { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ +#endif // TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/mutable_op_resolver_test.cc b/tensorflow/lite/mutable_op_resolver_test.cc similarity index 98% rename from tensorflow/contrib/lite/mutable_op_resolver_test.cc rename to tensorflow/lite/mutable_op_resolver_test.cc index b70c7038396782..64fc68a16ca62d 100644 --- a/tensorflow/contrib/lite/mutable_op_resolver_test.cc +++ b/tensorflow/lite/mutable_op_resolver_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/mutable_op_resolver.h" +#include "tensorflow/lite/mutable_op_resolver.h" #include -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD similarity index 100% rename from tensorflow/contrib/lite/nnapi/BUILD rename to tensorflow/lite/nnapi/BUILD diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/lite/nnapi/NeuralNetworksShim.h similarity index 99% rename from tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h rename to tensorflow/lite/nnapi/NeuralNetworksShim.h index eccf4aefb6372b..c39502f4acc5dc 100644 --- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/lite/nnapi/NeuralNetworksShim.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_NEURALNETWORKSSHIM_H_ -#define TENSORFLOW_CONTRIB_LITE_NNAPI_NEURALNETWORKSSHIM_H_ +#ifndef TENSORFLOW_LITE_NNAPI_NEURALNETWORKSSHIM_H_ +#define TENSORFLOW_LITE_NNAPI_NEURALNETWORKSSHIM_H_ #include #include @@ -1009,4 +1009,4 @@ inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) { /**/ -#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_NEURALNETWORKSSHIM_H_ +#endif // TENSORFLOW_LITE_NNAPI_NEURALNETWORKSSHIM_H_ diff --git a/tensorflow/contrib/lite/nnapi/README.md b/tensorflow/lite/nnapi/README.md similarity index 100% rename from tensorflow/contrib/lite/nnapi/README.md rename to tensorflow/lite/nnapi/README.md diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc similarity index 99% rename from tensorflow/contrib/lite/nnapi_delegate.cc rename to tensorflow/lite/nnapi_delegate.cc index 9cca2293324923..950bdb39425f89 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/lite/nnapi_delegate.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/lite/nnapi_delegate.h" #include #include #include #include -#include "tensorflow/contrib/lite/c/builtin_op_data.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/nnapi/NeuralNetworksShim.h" #ifdef __ANDROID__ #include diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/lite/nnapi_delegate.h similarity index 86% rename from tensorflow/contrib/lite/nnapi_delegate.h rename to tensorflow/lite/nnapi_delegate.h index 22359d557e61e3..63b408c1416ed1 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/lite/nnapi_delegate.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ -#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#ifndef TENSORFLOW_LITE_NNAPI_DELEGATE_H_ +#define TENSORFLOW_LITE_NNAPI_DELEGATE_H_ -#include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/interpreter.h" class ANeuralNetworksModel; class ANeuralNetworksMemory; @@ -77,4 +77,4 @@ class NNAPIDelegate { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#endif // TENSORFLOW_LITE_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/lite/nnapi_delegate_disabled.cc similarity index 96% rename from tensorflow/contrib/lite/nnapi_delegate_disabled.cc rename to tensorflow/lite/nnapi_delegate_disabled.cc index e3536d3db6c59f..44dc21f1b6c2b3 100644 --- a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc +++ b/tensorflow/lite/nnapi_delegate_disabled.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/lite/nnapi_delegate.h" #include diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/lite/op_resolver.h similarity index 73% rename from tensorflow/contrib/lite/op_resolver.h rename to tensorflow/lite/op_resolver.h index e93134cbdecd58..96490d44b91c10 100644 --- a/tensorflow/contrib/lite/op_resolver.h +++ b/tensorflow/lite/op_resolver.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Compatibility shim for moved header location. -#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ -#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ +#ifndef TENSORFLOW_LITE_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_OP_RESOLVER_H_ -#include "tensorflow/contrib/lite/core/api/op_resolver.h" -#include "tensorflow/contrib/lite/mutable_op_resolver.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/mutable_op_resolver.h" -#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ +#endif // TENSORFLOW_LITE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc similarity index 98% rename from tensorflow/contrib/lite/optional_debug_tools.cc rename to tensorflow/lite/optional_debug_tools.cc index 64ba2d8baa2ea2..020d1d8de5ff0e 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/optional_debug_tools.h" +#include "tensorflow/lite/optional_debug_tools.h" namespace tflite { diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/lite/optional_debug_tools.h similarity index 80% rename from tensorflow/contrib/lite/optional_debug_tools.h rename to tensorflow/lite/optional_debug_tools.h index 82a6e114a66eb3..fb2f78e5ae42ab 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.h +++ b/tensorflow/lite/optional_debug_tools.h @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ // Optional debugging functionality. For small sized binaries, these are not // needed. -#ifndef TENSORFLOW_CONTRIB_LITE_OPTIONAL_DEBUG_TOOLS_H_ -#define TENSORFLOW_CONTRIB_LITE_OPTIONAL_DEBUG_TOOLS_H_ +#ifndef TENSORFLOW_LITE_OPTIONAL_DEBUG_TOOLS_H_ +#define TENSORFLOW_LITE_OPTIONAL_DEBUG_TOOLS_H_ -#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/lite/interpreter.h" namespace tflite { @@ -26,4 +26,4 @@ void PrintInterpreterState(Interpreter* interpreter); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_OPTIONAL_DEBUG_TOOLS_H_ +#endif // TENSORFLOW_LITE_OPTIONAL_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD similarity index 71% rename from tensorflow/contrib/lite/profiling/BUILD rename to tensorflow/lite/profiling/BUILD index 1172722f7a7077..c7a8e4f06ae190 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/lite/profiling/BUILD @@ -2,7 +2,7 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") common_copts = [ "-Wall", @@ -22,7 +22,7 @@ cc_test( defines = ["TFLITE_PROFILING_ENABLED"], deps = [ ":profiler", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -48,9 +48,9 @@ cc_library( copts = common_copts, deps = [ ":profiler", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/core:stats_calculator_portable", + "//tensorflow/lite:framework", + "//tensorflow/lite/schema:schema_fbs", ], ) @@ -61,12 +61,12 @@ cc_test( tags = ["no_oss"], deps = [ ":profile_summarizer", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/kernels:kernel_util", - "//tensorflow/contrib/lite/kernels:test_util", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -78,7 +78,7 @@ cc_test( defines = ["TFLITE_PROFILING_ENABLED"], deps = [ ":profile_buffer", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/lite/profiling/profile_buffer.h similarity index 95% rename from tensorflow/contrib/lite/profiling/profile_buffer.h rename to tensorflow/lite/profiling/profile_buffer.h index 65d86dce47f397..247ebb37c53e7a 100644 --- a/tensorflow/contrib/lite/profiling/profile_buffer.h +++ b/tensorflow/lite/profiling/profile_buffer.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_BUFFER_H_ -#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_BUFFER_H_ +#ifndef TENSORFLOW_LITE_PROFILING_PROFILE_BUFFER_H_ +#define TENSORFLOW_LITE_PROFILING_PROFILE_BUFFER_H_ #include #include -#include "tensorflow/contrib/lite/profiling/time.h" +#include "tensorflow/lite/profiling/time.h" namespace tflite { namespace profiling { @@ -143,4 +143,4 @@ class ProfileBuffer { } // namespace profiling } // namespace tflite #endif // TFLITE_PROFILING_ENABLED -#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_BUFFER_H_ +#endif // TENSORFLOW_LITE_PROFILING_PROFILE_BUFFER_H_ diff --git a/tensorflow/contrib/lite/profiling/profile_buffer_test.cc b/tensorflow/lite/profiling/profile_buffer_test.cc similarity index 96% rename from tensorflow/contrib/lite/profiling/profile_buffer_test.cc rename to tensorflow/lite/profiling/profile_buffer_test.cc index b8784cca455cfc..6642a15884fdf5 100644 --- a/tensorflow/contrib/lite/profiling/profile_buffer_test.cc +++ b/tensorflow/lite/profiling/profile_buffer_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/profiling/profile_buffer.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/profiling/profile_buffer.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc similarity index 97% rename from tensorflow/contrib/lite/profiling/profile_summarizer.cc rename to tensorflow/lite/profiling/profile_summarizer.cc index 720bd717b9e3b0..64b1bd7ad771c1 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/lite/profiling/profile_summarizer.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/lite/profiling/profile_summarizer.h" #include -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace profiling { diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/lite/profiling/profile_summarizer.h similarity index 84% rename from tensorflow/contrib/lite/profiling/profile_summarizer.h rename to tensorflow/lite/profiling/profile_summarizer.h index a529ff87428d70..d4f5da7be96adc 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.h +++ b/tensorflow/lite/profiling/profile_summarizer.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ -#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ +#ifndef TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARIZER_H_ +#define TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARIZER_H_ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/profiling/profiler.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/profiling/profiler.h" #include "tensorflow/core/util/stats_calculator.h" namespace tflite { @@ -52,4 +52,4 @@ class ProfileSummarizer { } // namespace profiling } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ +#endif // TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARIZER_H_ diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc similarity index 93% rename from tensorflow/contrib/lite/profiling/profile_summarizer_test.cc rename to tensorflow/lite/profiling/profile_summarizer_test.cc index 465c294962df77..bbb64b832aecae 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/lite/profiling/profile_summarizer_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/kernel_util.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" -#include "tensorflow/contrib/lite/testing/util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/profiling/profile_summarizer.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/version.h" namespace tflite { namespace profiling { diff --git a/tensorflow/contrib/lite/profiling/profiler.h b/tensorflow/lite/profiling/profiler.h similarity index 95% rename from tensorflow/contrib/lite/profiling/profiler.h rename to tensorflow/lite/profiling/profiler.h index 8c3e4dc76d8061..89c05cba37b37a 100644 --- a/tensorflow/contrib/lite/profiling/profiler.h +++ b/tensorflow/lite/profiling/profiler.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILER_H_ -#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILER_H_ +#ifndef TENSORFLOW_LITE_PROFILING_PROFILER_H_ +#define TENSORFLOW_LITE_PROFILING_PROFILER_H_ #include -#include "tensorflow/contrib/lite/profiling/profile_buffer.h" +#include "tensorflow/lite/profiling/profile_buffer.h" #ifdef TFLITE_PROFILING_ENABLED @@ -176,4 +176,4 @@ class Profiler { #endif // TFLITE_PROFILING_ENABLED -#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILER_H_ +#endif // TENSORFLOW_LITE_PROFILING_PROFILER_H_ diff --git a/tensorflow/contrib/lite/profiling/profiler_test.cc b/tensorflow/lite/profiling/profiler_test.cc similarity index 97% rename from tensorflow/contrib/lite/profiling/profiler_test.cc rename to tensorflow/lite/profiling/profiler_test.cc index cf56eed2a4643e..82d053729c900f 100644 --- a/tensorflow/contrib/lite/profiling/profiler_test.cc +++ b/tensorflow/lite/profiling/profiler_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/profiling/profiler.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/profiling/profiler.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace profiling { diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/lite/profiling/time.cc similarity index 96% rename from tensorflow/contrib/lite/profiling/time.cc rename to tensorflow/lite/profiling/time.cc index 875ddb02bcfc30..3e7db03d9d8df1 100644 --- a/tensorflow/contrib/lite/profiling/time.cc +++ b/tensorflow/lite/profiling/time.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/profiling/time.h" +#include "tensorflow/lite/profiling/time.h" #if defined(_MSC_VER) #include // NOLINT(build/c++11) diff --git a/tensorflow/contrib/lite/profiling/time.h b/tensorflow/lite/profiling/time.h similarity index 84% rename from tensorflow/contrib/lite/profiling/time.h rename to tensorflow/lite/profiling/time.h index cc2ec319b8a95b..66233a480fd390 100644 --- a/tensorflow/contrib/lite/profiling/time.h +++ b/tensorflow/lite/profiling/time.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ -#define TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ +#ifndef TENSORFLOW_LITE_PROFILING_TIME_H_ +#define TENSORFLOW_LITE_PROFILING_TIME_H_ #include @@ -24,4 +24,4 @@ uint64_t NowMicros(); } // namespace time } // namespace profiling } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ +#endif // TENSORFLOW_LITE_PROFILING_TIME_H_ diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD new file mode 100644 index 00000000000000..017dd72f781561 --- /dev/null +++ b/tensorflow/lite/python/BUILD @@ -0,0 +1,189 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "interpreter_test_data", + srcs = glob(["**/testdata/*"]), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "interpreter", + srcs = [ + "interpreter.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interpreter_test", + srcs = ["interpreter_test.py"], + data = [":interpreter_test_data"], + srcs_version = "PY2AND3", + tags = ["no_oss"], + deps = [ + ":interpreter", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "tflite_convert", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ], +) + +py_library( + name = "lite", + srcs = ["lite.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":convert", + ":convert_saved_model", + ":interpreter", + ":lite_constants", + ":op_hint", + "//tensorflow/python:graph_util", + "//tensorflow/python/keras", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", + ], +) + +py_test( + name = "lite_test", + srcs = ["lite_test.py"], + data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_windows", + ], + deps = [ + ":lite", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +py_library( + name = "lite_constants", + srcs = ["lite_constants.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/lite/toco:toco_flags_proto_py", + ], +) + +py_library( + name = "convert", + srcs = ["convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite_constants", + "//tensorflow/lite/toco:model_flags_proto_py", + "//tensorflow/lite/toco:toco_flags_proto_py", + "//tensorflow/lite/toco/python:tensorflow_wrap_toco", + "//tensorflow/lite/toco/python:toco_from_protos", + "//tensorflow/python:platform", + ], +) + +py_library( + name = "op_hint", + srcs = ["op_hint.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_test( + name = "convert_test", + srcs = ["convert_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":convert", + ":interpreter", + ":op_hint", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + ], +) + +py_library( + name = "convert_saved_model", + srcs = ["convert_saved_model.py"], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow/contrib/lite:__subpackages__", + "//tensorflow/lite:__subpackages__", + ], + deps = [ + ":convert", + "//tensorflow/python:graph_util", + "//tensorflow/python:platform", + "//tensorflow/python/saved_model", + ], +) + +py_binary( + name = "create_custom_op", + srcs = ["create_custom_op.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + "@absl_py//absl/flags", + ], +) + +py_test( + name = "convert_saved_model_test", + srcs = ["convert_saved_model_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_windows", + ], + visibility = ["//visibility:public"], + deps = [ + ":convert_saved_model", + "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", + "//tensorflow/python:nn", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/keras", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model", + ], +) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/lite/python/convert.py similarity index 95% rename from tensorflow/contrib/lite/python/convert.py rename to tensorflow/lite/python/convert.py index 7e7a34d31065f4..9991fb2a7335dd 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -25,18 +25,19 @@ import subprocess as _subprocess import tempfile as _tempfile -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 -from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.lite.python import lite_constants +from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2 +from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation from tensorflow.python.util.lazy_loader import LazyLoader +from tensorflow.python.util.tf_export import tf_export as _tf_export # Lazy load since some of the performance benchmark skylark rules # break dependencies. _toco_python = LazyLoader( "tensorflow_wrap_toco", globals(), - "tensorflow.contrib.lite.toco.python." + "tensorflow.lite.toco.python." "tensorflow_wrap_toco") del LazyLoader @@ -90,11 +91,13 @@ class ConverterError(Exception): pass +# Don't expose these for now. +# @_tf_export("lite.toco_convert_protos") def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): """Convert `input_data_str` according to model and toco parameters. Unless you know what you are doing consider using - the more friendly `tf.contrib.lite.toco_convert`. + the more friendly `tf.lite.toco_convert`. Args: model_flags_str: Serialized proto describing model properties, see @@ -113,12 +116,12 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): # TODO(aselle): When toco does not use fatal errors for failure, we can # switch this on. if not _toco_from_proto_bin: - model_str = _toco_python.TocoConvert(model_flags_str, toco_flags_str, - input_data_str) - if not model_str: - raise ConverterError( - "TOCO returned an empty string. See console for more info.") - return model_str + try: + model_str = _toco_python.TocoConvert(model_flags_str, toco_flags_str, + input_data_str) + return model_str + except Exception as e: + raise ConverterError("TOCO failed: %s" % e) # Windows and TemporaryFile are not that useful together, # since you cannot have two readers/writers. So we have to @@ -181,6 +184,8 @@ def tensor_name(x): return x.name.split(":")[0] +# Don't expose these for now. +# @_tf_export("lite.build_toco_convert_protos") def build_toco_convert_protos(input_tensors, output_tensors, inference_type=lite_constants.FLOAT, @@ -288,10 +293,10 @@ def build_toco_convert_protos(input_tensors, toco.dump_graphviz_include_video = dump_graphviz_video if target_ops: if set(target_ops) == set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]): - toco.allow_flex_ops = True + toco.enable_select_tf_ops = True elif set(target_ops) == set([OpsSet.SELECT_TF_OPS]): - toco.allow_flex_ops = True - toco.force_flex_ops = True + toco.enable_select_tf_ops = True + toco.force_select_tf_ops = True model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges @@ -393,6 +398,7 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args, return data +@_tf_export("lite.toco_convert") @deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.") def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): """Convert a model using TOCO. diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/lite/python/convert_saved_model.py similarity index 99% rename from tensorflow/contrib/lite/python/convert_saved_model.py rename to tensorflow/lite/python/convert_saved_model.py index d18b60d0ea04ee..3f54d2559c4d85 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/lite/python/convert_saved_model.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python.convert import tensor_name +from tensorflow.lite.python.convert import tensor_name from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/lite/python/convert_saved_model_test.py similarity index 70% rename from tensorflow/contrib/lite/python/convert_saved_model_test.py rename to tensorflow/lite/python/convert_saved_model_test.py index fd81ac7f3883ac..dff582f1a16d2f 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/lite/python/convert_saved_model_test.py @@ -24,25 +24,17 @@ from __future__ import print_function import os -from tensorflow.contrib.lite.python import convert_saved_model -from tensorflow.python import keras +from tensorflow.lite.python import convert_saved_model from tensorflow.python.client import session -from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util -from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants -from tensorflow.python.training import training as train class TensorFunctionsTest(test_util.TensorFlowTestCase): @@ -310,150 +302,5 @@ def testMultipleMetaGraphDef(self): self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]]) -class Model(keras.Model): - """Model to recognize digits in the MNIST dataset. - - Train and export SavedModel, used for testOnflyTrainMnistSavedModel - - Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py - and - https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py - - But written as a ops.keras.Model using the layers API. - """ - - def __init__(self, data_format): - """Creates a model for classifying a hand-written digit. - - Args: - data_format: Either "channels_first" or "channels_last". - "channels_first" is typically faster on GPUs while "channels_last" is - typically faster on CPUs. See - https://www.tensorflow.org/performance/performance_guide#data_formats - """ - super(Model, self).__init__() - self._input_shape = [-1, 28, 28, 1] - - self.conv1 = layers.Conv2D( - 32, 5, padding="same", data_format=data_format, activation=nn.relu) - self.conv2 = layers.Conv2D( - 64, 5, padding="same", data_format=data_format, activation=nn.relu) - self.fc1 = layers.Dense(1024, activation=nn.relu) - self.fc2 = layers.Dense(10) - self.dropout = layers.Dropout(0.4) - self.max_pool2d = layers.MaxPooling2D( - (2, 2), (2, 2), padding="same", data_format=data_format) - - def __call__(self, inputs, training): - """Add operations to classify a batch of input images. - - Args: - inputs: A Tensor representing a batch of input images. - training: A boolean. Set to True to add operations required only when - training the classifier. - - Returns: - A logits Tensor with shape [, 10]. - """ - y = array_ops.reshape(inputs, self._input_shape) - y = self.conv1(y) - y = self.max_pool2d(y) - y = self.conv2(y) - y = self.max_pool2d(y) - y = layers.flatten(y) - y = self.fc1(y) - y = self.dropout(y, training=training) - return self.fc2(y) - - -def model_fn(features, labels, mode, params): - """The model_fn argument for creating an Estimator.""" - model = Model(params["data_format"]) - image = features - if isinstance(image, dict): - image = features["image"] - - if mode == estimator.ModeKeys.PREDICT: - logits = model(image, training=False) - predictions = { - "classes": math_ops.argmax(logits, axis=1), - "probabilities": nn.softmax(logits), - } - return estimator.EstimatorSpec( - mode=estimator.ModeKeys.PREDICT, - predictions=predictions, - export_outputs={ - "classify": estimator.export.PredictOutput(predictions) - }) - - elif mode == estimator.ModeKeys.TRAIN: - optimizer = train.AdamOptimizer(learning_rate=1e-4) - - logits = model(image, training=True) - loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) - return estimator.EstimatorSpec( - mode=estimator.ModeKeys.TRAIN, - loss=loss, - train_op=optimizer.minimize(loss, train.get_or_create_global_step())) - - elif mode == estimator.ModeKeys.EVAL: - logits = model(image, training=False) - loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) - return estimator.EstimatorSpec( - mode=estimator.ModeKeys.EVAL, - loss=loss, - eval_metric_ops={ - "accuracy": - ops.metrics.accuracy( - labels=labels, predictions=math_ops.argmax(logits, axis=1)), - }) - - -def dummy_input_fn(): - image = random_ops.random_uniform([100, 784]) - labels = random_ops.random_uniform([100, 1], maxval=9, dtype=dtypes.int32) - return image, labels - - -class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase): - - def testTrainedMnistSavedModel(self): - """Test mnist SavedModel, trained with dummy data and small steps.""" - # Build classifier - classifier = estimator.Estimator( - model_fn=model_fn, - params={ - "data_format": "channels_last" # tflite format - }) - - # Train and pred for serving - classifier.train(input_fn=dummy_input_fn, steps=2) - image = array_ops.placeholder(dtypes.float32, [None, 28, 28]) - pred_input_fn = estimator.export.build_raw_serving_input_receiver_fn({ - "image": image, - }) - - # Export SavedModel - saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel") - classifier.export_saved_model(saved_model_dir, pred_input_fn) - - # Convert to tflite and test output - saved_model_name = os.listdir(saved_model_dir)[0] - saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) - - # TODO(zhixianyan): no need to limit output_arrays to `Softmax' - # once b/74205001 fixed and argmax implemented in tflite. - result = convert_saved_model.freeze_saved_model( - saved_model_dir=saved_model_final_dir, - input_arrays=None, - input_shapes=None, - output_arrays=["Softmax"], - tag_set=set([tag_constants.SERVING]), - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) - - self.assertTrue(result) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py similarity index 95% rename from tensorflow/contrib/lite/python/convert_test.py rename to tensorflow/lite/python/convert_test.py index 40a8b5fafb2dbf..7a0bce921b599f 100644 --- a/tensorflow/contrib/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -19,10 +19,10 @@ import numpy as np -from tensorflow.contrib.lite.python import convert -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.python import op_hint -from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.lite.python import convert +from tensorflow.lite.python import lite_constants +from tensorflow.lite.python import op_hint +from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -146,7 +146,7 @@ def _getGraphOpTypes(self, graphdef, output_nodes): """Returns used op types in `graphdef` reachable from `output_nodes`. This is used to check that after the stub transformation the expected - nodes are there. Typically use this with self.assertCountEqual(...). + nodes are there. NOTE: this is not a exact test that the graph is the correct output, but it balances compact expressibility of test with sanity checking. @@ -196,11 +196,11 @@ def _swish(input_tensor, scale): stubbed_graphdef = op_hint.convert_op_hints_to_stubs( graph_def=sess.graph_def) - self.assertCountEqual( + self.assertEqual( self._getGraphOpTypes( stubbed_graphdef, output_nodes=[op_hint._tensor_name_base(output.name)]), - ["cool_activation", "Const", "Identity"]) + set(["cool_activation", "Const", "Identity"])) def testScaleAndBiasAndIdentity(self): """This tests a scaled add which has 3 inputs and 2 outputs.""" @@ -223,11 +223,11 @@ def _scaled_and_bias_and_identity(a, x, b): stubbed_graphdef = op_hint.convert_op_hints_to_stubs( graph_def=sess.graph_def) - self.assertCountEqual( + self.assertEqual( self._getGraphOpTypes( stubbed_graphdef, output_nodes=[op_hint._tensor_name_base(output.name)]), - ["scale_and_bias_and_identity", "Const", "Identity", "Pack"]) + set(["scale_and_bias_and_identity", "Const", "Identity", "Pack"])) def testTwoFunctions(self): """Tests if two functions are converted correctly.""" @@ -248,11 +248,11 @@ def _double_values(x): self.assertEqual(self._countIdentities(sess.graph_def.node), 5) stubbed_graphdef = op_hint.convert_op_hints_to_stubs( graph_def=sess.graph_def) - self.assertCountEqual( + self.assertEqual( self._getGraphOpTypes( stubbed_graphdef, output_nodes=[op_hint._tensor_name_base(output.name)]), - ["add_test", "Const", "Identity", "Add"]) + set(["add_test", "Const", "Identity", "Add"])) def _get_input_index(self, x): return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i @@ -323,11 +323,11 @@ def testAggregate(self): with self.cached_session() as sess: stubbed_graphdef = op_hint.convert_op_hints_to_stubs( graph_def=sess.graph_def) - self.assertCountEqual( + self.assertEqual( self._getGraphOpTypes( stubbed_graphdef, output_nodes=[op_hint._tensor_name_base(output.name)]), - ["agg", "Const", "Identity"]) + set(["agg", "Const", "Identity"])) if __name__ == "__main__": diff --git a/tensorflow/contrib/lite/python/create_custom_op.py b/tensorflow/lite/python/create_custom_op.py similarity index 98% rename from tensorflow/contrib/lite/python/create_custom_op.py rename to tensorflow/lite/python/create_custom_op.py index 830f95358c4550..344cd28d160f2d 100644 --- a/tensorflow/contrib/lite/python/create_custom_op.py +++ b/tensorflow/lite/python/create_custom_op.py @@ -19,7 +19,7 @@ Example: -bazel run tensorflow/contrib/lite/python:create_custom_op -- \ +bazel run tensorflow/lite/python:create_custom_op -- \ --input_graph=/tmp/input.pb \ --output_graph=/tmp/output.pb \ --inputs=concat,concat_1 \ diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py similarity index 98% rename from tensorflow/contrib/lite/python/interpreter.py rename to tensorflow/lite/python/interpreter.py index 4bacccdbacf3e8..a6183d13b56c78 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -20,6 +20,7 @@ import sys import numpy as np from tensorflow.python.util.lazy_loader import LazyLoader +from tensorflow.python.util.tf_export import tf_export as _tf_export # Lazy load since some of the performance benchmark skylark rules # break dependencies. Must use double quotes to match code internal rewrite @@ -27,13 +28,14 @@ # pylint: disable=g-inconsistent-quotes _interpreter_wrapper = LazyLoader( "_interpreter_wrapper", globals(), - "tensorflow.contrib.lite.python.interpreter_wrapper." + "tensorflow.lite.python.interpreter_wrapper." "tensorflow_wrap_interpreter_wrapper") # pylint: enable=g-inconsistent-quotes del LazyLoader +@_tf_export('lite.Interpreter') class Interpreter(object): """Interpreter inferace for TF-Lite Models.""" diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py similarity index 98% rename from tensorflow/contrib/lite/python/interpreter_test.py rename to tensorflow/lite/python/interpreter_test.py index e77d52ca9950ec..7ec56a21c9ffa8 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -21,7 +21,7 @@ import numpy as np import six -from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper +from tensorflow.lite.python import interpreter as interpreter_wrapper from tensorflow.python.framework import test_util from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD similarity index 86% rename from tensorflow/contrib/lite/python/interpreter_wrapper/BUILD rename to tensorflow/lite/python/interpreter_wrapper/BUILD index 69ee95c320b72b..767a9fc476398d 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -11,8 +11,8 @@ cc_library( srcs = ["interpreter_wrapper.cc"], hdrs = ["interpreter_wrapper.h"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc similarity index 98% rename from tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc rename to tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 1e2384b6d23167..e71752fe6318e8 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" #include #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h similarity index 93% rename from tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h rename to tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index b98046fe8a2ce5..ffb02780255e41 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ -#define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #include #include @@ -104,4 +104,4 @@ class InterpreterWrapper { } // namespace interpreter_wrapper } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i similarity index 88% rename from tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i rename to tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i index afb2092eacab1d..f52ef1eeca7db3 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -18,13 +18,13 @@ limitations under the License. %{ #define SWIG_FILE_WITH_INIT -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" %} -%include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +%include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" namespace tflite { namespace interpreter_wrapper { diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/lite/python/lite.py similarity index 94% rename from tensorflow/contrib/lite/python/lite.py rename to tensorflow/lite/python/lite.py index 155b436a31b7ec..5810553da2cf8e 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -39,21 +39,21 @@ from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError -from tensorflow.contrib.lite.python import lite_constants as constants -from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert import ConverterError # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert import OpsSet -from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name -from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def -from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl -from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model -from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names -from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes -from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import -from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import -from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.lite.python import lite_constants as constants +from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import +from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import +from tensorflow.lite.python.convert import OpsSet +from tensorflow.lite.python.convert import tensor_name as _tensor_name +from tensorflow.lite.python.convert import toco_convert # pylint: disable=unused-import +from tensorflow.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def +from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl +from tensorflow.lite.python.convert import toco_convert_protos # pylint: disable=unused-import +from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model +from tensorflow.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names +from tensorflow.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes +from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import +from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import +from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session @@ -65,8 +65,10 @@ from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants from tensorflow.python.util import deprecation as _deprecation +from tensorflow.python.util.tf_export import tf_export as _tf_export +@_tf_export("lite.TFLiteConverter") class TFLiteConverter(object): """Convert a TensorFlow model into `output_format` using TOCO. @@ -499,6 +501,7 @@ def _set_batch_size(self, batch_size): tensor.set_shape(shape) +@_tf_export("lite.TocoConverter") class TocoConverter(object): """Convert a TensorFlow model into `output_format` using TOCO. diff --git a/tensorflow/contrib/lite/python/lite_constants.py b/tensorflow/lite/python/lite_constants.py similarity index 71% rename from tensorflow/contrib/lite/python/lite_constants.py rename to tensorflow/lite/python/lite_constants.py index f3c01f455b50f6..fdefc5e6cf0448 100644 --- a/tensorflow/contrib/lite/python/lite_constants.py +++ b/tensorflow/lite/python/lite_constants.py @@ -18,9 +18,10 @@ from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export as _tf_export # Enum types from the protobuf promoted to the API FLOAT = _types_pb2.FLOAT @@ -33,6 +34,16 @@ TFLITE = _toco_flags_pb2.TFLITE GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT +_tf_export("lite.constants.FLOAT").export_constant(__name__, "FLOAT") +_tf_export("lite.constants.INT32").export_constant(__name__, "INT32") +_tf_export("lite.constants.INT64").export_constant(__name__, "INT64") +_tf_export("lite.constants.STRING").export_constant(__name__, "STRING") +_tf_export("lite.constants.QUANTIZED_UINT8").export_constant( + __name__, "QUANTIZED_UINT8") +_tf_export("lite.constants.TFLITE").export_constant(__name__, "TFLITE") +_tf_export("lite.constants.GRAPHVIZ_DOT").export_constant( + __name__, "GRAPHVIZ_DOT") + # Currently the default mode of operation is to shell to another python process # to protect against crashes. However, it breaks some dependent targets because # it forces us to depend on an external py_binary. The experimental API doesn't diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py similarity index 98% rename from tensorflow/contrib/lite/python/lite_test.py rename to tensorflow/lite/python/lite_test.py index 70494afa155e28..5a5697db92b216 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -22,9 +22,9 @@ import tempfile import numpy as np -from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.lite.python import lite +from tensorflow.lite.python import lite_constants +from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.framework import constant_op @@ -594,8 +594,17 @@ def testInvalidFileBadData(self): # TODO(nupurgarg): Test model loading in open source. def _initObjectDetectionArgs(self): # Initializes the arguments required for the object detection model. - self._graph_def_file = resource_loader.get_path_to_datafile( - 'testdata/tflite_graph.pb') + # Looks for the model file which is saved in a different location interally + # and externally. + filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb') + if not os.path.exists(filename): + filename = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb') + if not os.path.exists(filename): + raise IOError("File '{0}' does not exist.".format(filename)) + + self._graph_def_file = filename self._input_arrays = ['normalized_input_image_tensor'] self._output_arrays = [ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py similarity index 98% rename from tensorflow/contrib/lite/python/op_hint.py rename to tensorflow/lite/python/op_hint.py index 8c920132e5c2dd..3afce1baf2e3c2 100644 --- a/tensorflow/contrib/lite/python/op_hint.py +++ b/tensorflow/lite/python/op_hint.py @@ -24,7 +24,7 @@ Example: def tflite_cool_activation(input): # A cool activation function. - custom = tf.contrib.lite.OpHint("cool_activation") + custom = tf.lite.OpHint("cool_activation") input, = custom.add_inputs(input) output = tf.sigmoid(input) * input output, = custom.add_outputs(output) @@ -35,8 +35,8 @@ def tflite_cool_activation(input): session = tf.Session() - graphdef_to_convert = tf.contrib.lite.convert_op_hints_to_stubs(session) - tflite_graph = tf.contrib.lite.toco_convert(graphdef_to_convert, + graphdef_to_convert = tf.lite.convert_op_hints_to_stubs(session) + tflite_graph = tf.lite.toco_convert(graphdef_to_convert, [image], [output]) [image], [output]) with open("/tmp/graph.fb", "wb") as fp: @@ -86,8 +86,10 @@ def tflite_cool_activation(input): from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.util import compat as _compat from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export as _tf_export +@_tf_export("lite.OpHint") class OpHint(object): """A class that helps build tflite function invocations. @@ -136,14 +138,14 @@ class OpHint(object): # Types of aggregations # stack: stacks all ophints with matching tags. i.e. for a static rnn. # specifically, this is good for an input or output to a static rnn cell. - AGGREGATE_STACK = _compat.as_bytes("stack") + AGGREGATE_STACK = "stack" # first: only takes the first output (one with lowest sort index) # of matching tags. This is good for the input state to an RNN. - AGGREGATE_FIRST = _compat.as_bytes("first") + AGGREGATE_FIRST = "first" # aggregation last takes only the last tag (one with highest sort index). # This is good for an output value on the last stack item of a # static rnn. - AGGREGATE_LAST = _compat.as_bytes("last") + AGGREGATE_LAST = "last" class OpHintArgumentTracker(object): """Conceptually tracks indices of arguments of "OpHint functions". @@ -656,7 +658,7 @@ def _find_all_hints_in_graph_def(graphdef): if sort == -1: sort = None aggregation = None if OpHint.FUNCTION_AGGREGATE_ATTR in attr: - aggregation = attr[OpHint.FUNCTION_AGGREGATE_ATTR].s + aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s) # Add the input or output def put_operand(stuff, index, sort, operand, aggregation): @@ -936,6 +938,7 @@ def _remove_redundant_stack_unstack(graph_def): return curr +@_tf_export("lite.convert_op_hints_to_stubs") def _convert_op_hints_to_stubs_helper( graph_def, write_callback=lambda sess, graph_def: None): """Converts a graph_def to a new graph_def where all op hints are stubbed. diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py similarity index 98% rename from tensorflow/contrib/lite/python/tflite_convert.py rename to tensorflow/lite/python/tflite_convert.py index 551424c4b43203..00ea6d722e2493 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -22,10 +22,10 @@ import os import sys -from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.lite.python import lite +from tensorflow.lite.python import lite_constants +from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.platform import app diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/lite/schema/BUILD similarity index 96% rename from tensorflow/contrib/lite/schema/BUILD rename to tensorflow/lite/schema/BUILD index d892466c7a1d9c..69d5458c6e432a 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -5,7 +5,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") py_binary( name = "upgrade_schema", diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD b/tensorflow/lite/schema/builtin_ops_header/BUILD similarity index 86% rename from tensorflow/contrib/lite/schema/builtin_ops_header/BUILD rename to tensorflow/lite/schema/builtin_ops_header/BUILD index 4a627761daf45b..8a01541d575e28 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD +++ b/tensorflow/lite/schema/builtin_ops_header/BUILD @@ -9,7 +9,7 @@ cc_library( srcs = ["generator.cc"], hdrs = ["generator.h"], deps = [ - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_fbs", ], ) @@ -35,7 +35,7 @@ cc_test( name = "consistency_test", srcs = ["consistency_test.cc"], data = [ - "//tensorflow/contrib/lite:builtin_ops.h", + "//tensorflow/lite:builtin_ops.h", ], tags = ["no_oss"], deps = [ diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/README.md b/tensorflow/lite/schema/builtin_ops_header/README.md similarity index 65% rename from tensorflow/contrib/lite/schema/builtin_ops_header/README.md rename to tensorflow/lite/schema/builtin_ops_header/README.md index f20d4f664e62fd..e34a30b8182560 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/README.md +++ b/tensorflow/lite/schema/builtin_ops_header/README.md @@ -7,6 +7,6 @@ Whenever you add a new builtin op, please execute: ```sh bazel run \ - //tensorflow/contrib/lite/schema/builtin_ops_header:generate > \ - tensorflow/contrib/lite/builtin_ops.h + //tensorflow/lite/schema/builtin_ops_header:generate > \ + tensorflow/lite/builtin_ops.h ``` diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc b/tensorflow/lite/schema/builtin_ops_header/consistency_test.cc similarity index 93% rename from tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc rename to tensorflow/lite/schema/builtin_ops_header/consistency_test.cc index d55c125c117db3..f62dcda2e82851 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc +++ b/tensorflow/lite/schema/builtin_ops_header/consistency_test.cc @@ -15,12 +15,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include "tensorflow/lite/schema/builtin_ops_header/generator.h" namespace { const char* kHeaderFileName = - "tensorflow/contrib/lite/builtin_ops.h"; + "tensorflow/lite/builtin_ops.h"; // The test ensures that `builtin_ops.h` is consistent with the FlatBuffer // schema definition. When the schema is modified, it's required to run the diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc b/tensorflow/lite/schema/builtin_ops_header/generate.cc similarity index 92% rename from tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc rename to tensorflow/lite/schema/builtin_ops_header/generate.cc index 72a28987b8d486..125dcd485be0fb 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc +++ b/tensorflow/lite/schema/builtin_ops_header/generate.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include "tensorflow/lite/schema/builtin_ops_header/generator.h" // This executable is used to generate builtin_ops.h in TensorFlow Lite. // Please see README.md for more details. diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/lite/schema/builtin_ops_header/generator.cc similarity index 92% rename from tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc rename to tensorflow/lite/schema/builtin_ops_header/generator.cc index 9dc8daa227dd68..e2967aee0ff4cb 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/lite/schema/builtin_ops_header/generator.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/builtin_ops_header/generator.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace builtin_ops_header { @@ -35,8 +35,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#ifndef TENSORFLOW_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by // `schema/builtin_ops_header/generator.cc`. @@ -56,7 +56,7 @@ const char* kFileFooter = #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#endif // TENSORFLOW_LITE_BUILTIN_OPS_H_ )"; } // anonymous namespace diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h b/tensorflow/lite/schema/builtin_ops_header/generator.h similarity index 86% rename from tensorflow/contrib/lite/schema/builtin_ops_header/generator.h rename to tensorflow/lite/schema/builtin_ops_header/generator.h index 3241ff83d599ed..8c9383a992daa5 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h +++ b/tensorflow/lite/schema/builtin_ops_header/generator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // An utility library to generate pure C header for builtin ops definition. -#ifndef TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ +#ifndef TENSORFLOW_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ +#define TENSORFLOW_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ #include @@ -35,4 +35,4 @@ bool GenerateHeader(std::ostream& os); } // namespace builtin_ops_header } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ +#endif // TENSORFLOW_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc b/tensorflow/lite/schema/builtin_ops_header/generator_test.cc similarity index 96% rename from tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc rename to tensorflow/lite/schema/builtin_ops_header/generator_test.cc index a7dc8e1b0486ed..c508c981bb3a75 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc +++ b/tensorflow/lite/schema/builtin_ops_header/generator_test.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include "tensorflow/lite/schema/builtin_ops_header/generator.h" #include #include diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/lite/schema/flatbuffer_compatibility_test.cc similarity index 96% rename from tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc rename to tensorflow/lite/schema/flatbuffer_compatibility_test.cc index 22b4616ccbb756..86177aeb127246 100644 --- a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc +++ b/tensorflow/lite/schema/flatbuffer_compatibility_test.cc @@ -62,9 +62,9 @@ TEST(SchemaTest, TestCompatibility) { // TODO(aselle): Need a reliable way to load files. std::string base_contents, current_contents; const char *base_filename = - TFLITE_TF_PREFIX "contrib/lite/schema/schema_v3.fbs"; + TFLITE_TF_PREFIX "lite/schema/schema_v3.fbs"; const char *current_filename = - TFLITE_TF_PREFIX "contrib/lite/schema/schema.fbs"; + TFLITE_TF_PREFIX "lite/schema/schema.fbs"; ASSERT_TRUE(LoadFileRaw(base_filename, &base_contents)); ASSERT_TRUE(LoadFileRaw(current_filename, ¤t_contents)); diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs similarity index 100% rename from tensorflow/contrib/lite/schema/schema.fbs rename to tensorflow/lite/schema/schema.fbs diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h similarity index 100% rename from tensorflow/contrib/lite/schema/schema_generated.h rename to tensorflow/lite/schema/schema_generated.h diff --git a/tensorflow/contrib/lite/schema/schema_v0.fbs b/tensorflow/lite/schema/schema_v0.fbs similarity index 100% rename from tensorflow/contrib/lite/schema/schema_v0.fbs rename to tensorflow/lite/schema/schema_v0.fbs diff --git a/tensorflow/contrib/lite/schema/schema_v1.fbs b/tensorflow/lite/schema/schema_v1.fbs similarity index 100% rename from tensorflow/contrib/lite/schema/schema_v1.fbs rename to tensorflow/lite/schema/schema_v1.fbs diff --git a/tensorflow/contrib/lite/schema/schema_v2.fbs b/tensorflow/lite/schema/schema_v2.fbs similarity index 100% rename from tensorflow/contrib/lite/schema/schema_v2.fbs rename to tensorflow/lite/schema/schema_v2.fbs diff --git a/tensorflow/contrib/lite/schema/schema_v3.fbs b/tensorflow/lite/schema/schema_v3.fbs similarity index 100% rename from tensorflow/contrib/lite/schema/schema_v3.fbs rename to tensorflow/lite/schema/schema_v3.fbs diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/lite/schema/upgrade_schema.py similarity index 97% rename from tensorflow/contrib/lite/schema/upgrade_schema.py rename to tensorflow/lite/schema/upgrade_schema.py index a2ddf6295014f3..d9220ba10ca2e9 100644 --- a/tensorflow/contrib/lite/schema/upgrade_schema.py +++ b/tensorflow/lite/schema/upgrade_schema.py @@ -16,11 +16,11 @@ Usage examples: -bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json -bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin -bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json -bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin -bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite +bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.json +bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.bin +bazel run tensorflow/lite/schema/upgrade_schema -- in.bin out.json +bazel run tensorflow/lite/schema/upgrade_schema -- in.json out.bin +bazel run tensorflow/lite/schema/upgrade_schema -- in.tflite out.tflite """ from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/lite/schema/upgrade_schema_test.py b/tensorflow/lite/schema/upgrade_schema_test.py similarity index 99% rename from tensorflow/contrib/lite/schema/upgrade_schema_test.py rename to tensorflow/lite/schema/upgrade_schema_test.py index b5002e6f7576b6..922968c65aa760 100644 --- a/tensorflow/contrib/lite/schema/upgrade_schema_test.py +++ b/tensorflow/lite/schema/upgrade_schema_test.py @@ -20,7 +20,7 @@ import json import tempfile -from tensorflow.contrib.lite.schema import upgrade_schema as upgrade_schema_lib +from tensorflow.lite.schema import upgrade_schema as upgrade_schema_lib from tensorflow.python.framework import test_util from tensorflow.python.platform import test as test_lib diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc similarity index 98% rename from tensorflow/contrib/lite/simple_memory_arena.cc rename to tensorflow/lite/simple_memory_arena.cc index cd0f1f7c17a50f..88bdf50c9b64c6 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/lite/simple_memory_arena.h" #include #include diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/lite/simple_memory_arena.h similarity index 92% rename from tensorflow/contrib/lite/simple_memory_arena.h rename to tensorflow/lite/simple_memory_arena.h index 45d0d8735ee10a..42203c0c0a32d6 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/lite/simple_memory_arena.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ -#define TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ +#ifndef TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ +#define TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -86,4 +86,4 @@ class SimpleMemoryArena { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ +#endif // TENSORFLOW_LITE_SIMPLE_MEMORY_ARENA_H_ diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/lite/simple_memory_arena_test.cc similarity index 97% rename from tensorflow/contrib/lite/simple_memory_arena_test.cc rename to tensorflow/lite/simple_memory_arena_test.cc index 60d4d5e768aeda..caf13db2c1a6e9 100644 --- a/tensorflow/contrib/lite/simple_memory_arena_test.cc +++ b/tensorflow/lite/simple_memory_arena_test.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/lite/simple_memory_arena.h" #include #include -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/special_rules.bzl b/tensorflow/lite/special_rules.bzl similarity index 59% rename from tensorflow/contrib/lite/special_rules.bzl rename to tensorflow/lite/special_rules.bzl index 54083c49182c70..e10af3d240eebf 100644 --- a/tensorflow/contrib/lite/special_rules.bzl +++ b/tensorflow/lite/special_rules.bzl @@ -1,6 +1,6 @@ """External versions of build rules that differ outside of Google.""" def tflite_portable_test_suite(**kwargs): - """This is a no-op outside of Google.""" - _ignore = [kwargs] - pass + """This is a no-op outside of Google.""" + _ignore = [kwargs] + pass diff --git a/tensorflow/contrib/lite/stderr_reporter.cc b/tensorflow/lite/stderr_reporter.cc similarity index 96% rename from tensorflow/contrib/lite/stderr_reporter.cc rename to tensorflow/lite/stderr_reporter.cc index e29a6345fdfe4c..09eb1d254a608b 100644 --- a/tensorflow/contrib/lite/stderr_reporter.cc +++ b/tensorflow/lite/stderr_reporter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/stderr_reporter.h" +#include "tensorflow/lite/stderr_reporter.h" #include #include diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/lite/stderr_reporter.h similarity index 78% rename from tensorflow/contrib/lite/stderr_reporter.h rename to tensorflow/lite/stderr_reporter.h index c6f4ffbdffb4b3..7582b421ee3c95 100644 --- a/tensorflow/contrib/lite/stderr_reporter.h +++ b/tensorflow/lite/stderr_reporter.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ -#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ +#ifndef TENSORFLOW_LITE_STDERR_REPORTER_H_ +#define TENSORFLOW_LITE_STDERR_REPORTER_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" namespace tflite { @@ -31,4 +31,4 @@ ErrorReporter* DefaultErrorReporter(); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ +#endif // TENSORFLOW_LITE_STDERR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/lite/string.h similarity index 86% rename from tensorflow/contrib/lite/string.h rename to tensorflow/lite/string.h index af3fadfcb35074..65142b11de389f 100644 --- a/tensorflow/contrib/lite/string.h +++ b/tensorflow/lite/string.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Abstract string. We don't want even absl at this level. -#ifndef TENSORFLOW_CONTRIB_LITE_STRING_H_ -#define TENSORFLOW_CONTRIB_LITE_STRING_H_ +#ifndef TENSORFLOW_LITE_STRING_H_ +#define TENSORFLOW_LITE_STRING_H_ #include @@ -26,4 +26,4 @@ using std::string; } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_STRING_H_ +#endif // TENSORFLOW_LITE_STRING_H_ diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/lite/string_util.cc similarity index 90% rename from tensorflow/contrib/lite/string_util.cc rename to tensorflow/lite/string_util.cc index b991e999b66daa..1b33f5bcba01bf 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/lite/string_util.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/string_util.h" +#include #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { namespace { @@ -97,13 +97,19 @@ int DynamicBuffer::WriteToBuffer(char** buffer) { } void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) { + // Set tensor content pointer to tensor_buffer, and release original data. + auto dims = TfLiteIntArrayCreate(1); + dims->data[0] = offset_.size() - 1; // Store number of strings. + WriteToTensor(tensor, dims); +} + +void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor, + TfLiteIntArray* new_shape) { char* tensor_buffer; int bytes = WriteToBuffer(&tensor_buffer); // Set tensor content pointer to tensor_buffer, and release original data. - auto dims = TfLiteIntArrayCreate(1); - dims->data[0] = offset_.size() - 1; // Store number of strings. - TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params, + TfLiteTensorReset(tensor->type, tensor->name, new_shape, tensor->params, tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation, tensor->is_variable, tensor); } diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/lite/string_util.h similarity index 87% rename from tensorflow/contrib/lite/string_util.h rename to tensorflow/lite/string_util.h index d24627b509558d..c9b74482f7d04b 100644 --- a/tensorflow/contrib/lite/string_util.h +++ b/tensorflow/lite/string_util.h @@ -37,13 +37,13 @@ limitations under the License. // # described above. // buf.WriteToTensor(tensor) -#ifndef TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ +#ifndef TENSORFLOW_LITE_STRING_UTIL_H_ +#define TENSORFLOW_LITE_STRING_UTIL_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/string.h" namespace tflite { @@ -74,7 +74,11 @@ class DynamicBuffer { // The function allocates space for the buffer but does NOT take ownership. int WriteToBuffer(char** buffer); - // Fill content into a string tensor. + // Fill content into a string tensor, with the given new_shape. The new + // shape must match the number of strings in this object. + void WriteToTensor(TfLiteTensor* tensor, TfLiteIntArray* new_shape); + + // Fill content into a string tensor. Set shape to {num_strings}. void WriteToTensor(TfLiteTensor* tensor); private: @@ -94,4 +98,4 @@ StringRef GetString(const char* raw_buffer, int string_index); StringRef GetString(const TfLiteTensor* tensor, int string_index); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ +#endif // TENSORFLOW_LITE_STRING_UTIL_H_ diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/lite/string_util_test.cc similarity index 86% rename from tensorflow/contrib/lite/string_util_test.cc rename to tensorflow/lite/string_util_test.cc index a583a9184be91b..377cdd77eb4651 100644 --- a/tensorflow/contrib/lite/string_util_test.cc +++ b/tensorflow/lite/string_util_test.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/string_util.h" #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { @@ -50,9 +50,21 @@ TEST(StringUtil, TestStringUtil) { DynamicBuffer buf1; buf1.AddString(s1.data(), s1.length()); buf0.AddString(s2, 0); - buf0.WriteToTensor(t0); + + auto new_shape = TfLiteIntArrayCreate(2); + new_shape->data[0] = 2; + new_shape->data[1] = 1; + buf0.WriteToTensor(t0, new_shape); buf1.WriteToTensor(t1); + // Check tensor shapes. + EXPECT_EQ(t0->dims->size, 2); + EXPECT_EQ(t0->dims->data[0], 2); + EXPECT_EQ(t0->dims->data[1], 1); + + EXPECT_EQ(t1->dims->size, 1); + EXPECT_EQ(t1->dims->data[0], 1); + // Read strings from tensors. ASSERT_EQ(GetStringCount(t0), 2); StringRef str_ref; diff --git a/tensorflow/contrib/lite/testdata/0_subgraphs.bin b/tensorflow/lite/testdata/0_subgraphs.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/0_subgraphs.bin rename to tensorflow/lite/testdata/0_subgraphs.bin diff --git a/tensorflow/contrib/lite/testdata/2_subgraphs.bin b/tensorflow/lite/testdata/2_subgraphs.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/2_subgraphs.bin rename to tensorflow/lite/testdata/2_subgraphs.bin diff --git a/tensorflow/lite/testdata/add.bin b/tensorflow/lite/testdata/add.bin new file mode 100644 index 00000000000000..b4c02350c09130 Binary files /dev/null and b/tensorflow/lite/testdata/add.bin differ diff --git a/tensorflow/lite/testdata/add.json b/tensorflow/lite/testdata/add.json new file mode 100644 index 00000000000000..f589bebfbf257b --- /dev/null +++ b/tensorflow/lite/testdata/add.json @@ -0,0 +1,79 @@ +{ + version: 3, + operator_codes: [ + { + } + ], + subgraphs: [ + { + tensors: [ + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "add" + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "input" + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "output" + } + ], + inputs: [ + 1 + ], + outputs: [ + 2 + ], + operators: [ + { + inputs: [ + 1, + 1 + ], + outputs: [ + 0 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + }, + { + inputs: [ + 0, + 1 + ], + outputs: [ + 2 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + } + ] + } + ], + buffers: [ + { + data: [ + + ] + } + ] +} diff --git a/tensorflow/lite/testdata/add_quantized.bin b/tensorflow/lite/testdata/add_quantized.bin new file mode 100644 index 00000000000000..07d48b93eb87f9 Binary files /dev/null and b/tensorflow/lite/testdata/add_quantized.bin differ diff --git a/tensorflow/lite/testdata/add_quantized.json b/tensorflow/lite/testdata/add_quantized.json new file mode 100644 index 00000000000000..f70ed8143e99c7 --- /dev/null +++ b/tensorflow/lite/testdata/add_quantized.json @@ -0,0 +1,123 @@ +{ + version: 3, + operator_codes: [ + { + } + ], + subgraphs: [ + { + tensors: [ + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "add", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + type: "UINT8", + name: "input", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + type: "UINT8", + name: "output", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + } + ], + inputs: [ + 1 + ], + outputs: [ + 2 + ], + operators: [ + { + inputs: [ + 1, + 1 + ], + outputs: [ + 0 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + }, + { + inputs: [ + 0, + 1 + ], + outputs: [ + 2 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + } + ] + } + ], + buffers: [ + { + data: [ + + ] + } + ] +} diff --git a/tensorflow/contrib/lite/testdata/empty_model.bin b/tensorflow/lite/testdata/empty_model.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/empty_model.bin rename to tensorflow/lite/testdata/empty_model.bin diff --git a/tensorflow/contrib/lite/testdata/multi_add.bin b/tensorflow/lite/testdata/multi_add.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/multi_add.bin rename to tensorflow/lite/testdata/multi_add.bin diff --git a/tensorflow/contrib/lite/testdata/multi_add.json b/tensorflow/lite/testdata/multi_add.json similarity index 100% rename from tensorflow/contrib/lite/testdata/multi_add.json rename to tensorflow/lite/testdata/multi_add.json diff --git a/tensorflow/contrib/lite/testdata/multi_add.pb b/tensorflow/lite/testdata/multi_add.pb similarity index 100% rename from tensorflow/contrib/lite/testdata/multi_add.pb rename to tensorflow/lite/testdata/multi_add.pb diff --git a/tensorflow/contrib/lite/testdata/multi_add_flex.bin b/tensorflow/lite/testdata/multi_add_flex.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/multi_add_flex.bin rename to tensorflow/lite/testdata/multi_add_flex.bin diff --git a/tensorflow/contrib/lite/testdata/no_subgraphs.bin b/tensorflow/lite/testdata/no_subgraphs.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/no_subgraphs.bin rename to tensorflow/lite/testdata/no_subgraphs.bin diff --git a/tensorflow/contrib/lite/testdata/test_model.bin b/tensorflow/lite/testdata/test_model.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/test_model.bin rename to tensorflow/lite/testdata/test_model.bin diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.bin b/tensorflow/lite/testdata/test_model_broken.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/test_model_broken.bin rename to tensorflow/lite/testdata/test_model_broken.bin diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.json b/tensorflow/lite/testdata/test_model_broken.json similarity index 100% rename from tensorflow/contrib/lite/testdata/test_model_broken.json rename to tensorflow/lite/testdata/test_model_broken.json diff --git a/tensorflow/contrib/lite/testdata/two_subgraphs.bin b/tensorflow/lite/testdata/two_subgraphs.bin similarity index 100% rename from tensorflow/contrib/lite/testdata/two_subgraphs.bin rename to tensorflow/lite/testdata/two_subgraphs.bin diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/lite/testing/BUILD similarity index 80% rename from tensorflow/contrib/lite/testing/BUILD rename to tensorflow/lite/testing/BUILD index 891d44d2b60c71..df448e8a880f68 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -5,11 +5,11 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load( - "//tensorflow/contrib/lite:build_def.bzl", + "//tensorflow/lite:build_def.bzl", "gen_zip_test", "generated_test_models_all", ) -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -36,7 +36,6 @@ load( shard_count = 20, tags = tags + [ "gen_zip_test", - "no_oss", "tflite_not_portable_intentional", ], test_name = test_name, @@ -46,9 +45,9 @@ load( ":util", "@com_google_googletest//:gtest", "@com_googlesource_code_re2//:re2", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", ] + select({ "//conditions:default": [ "//tensorflow/core:framework_internal", @@ -73,7 +72,7 @@ py_binary( name = "generate_examples", srcs = ["generate_examples.py"], data = [ - "//tensorflow/contrib/lite/toco", + "//tensorflow/lite/toco", ], srcs_version = "PY2AND3", deps = [ @@ -99,7 +98,7 @@ cc_library( ":message", ":split", ":test_runner", - "//tensorflow/contrib/lite:framework", + "//tensorflow/lite:framework", ], ) @@ -124,7 +123,7 @@ cc_library( srcs = ["split.cc"], hdrs = ["split.h"], deps = [ - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", ], ) @@ -141,7 +140,7 @@ cc_test( cc_library( name = "join", hdrs = ["join.h"], - deps = ["//tensorflow/contrib/lite:string"], + deps = ["//tensorflow/lite:string"], ) cc_test( @@ -161,10 +160,10 @@ cc_library( deps = [ ":split", ":test_runner", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/delegates/flex:delegate", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:builtin_op_data", + "//tensorflow/lite:framework", + "//tensorflow/lite/delegates/flex:delegate", + "//tensorflow/lite/kernels:builtin_ops", ], ) @@ -172,7 +171,7 @@ tf_cc_test( name = "tflite_driver_test", size = "small", srcs = ["tflite_driver_test.cc"], - data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + data = ["//tensorflow/lite:testdata/multi_add.bin"], tags = [ "tflite_not_portable_android", "tflite_not_portable_ios", @@ -188,7 +187,7 @@ cc_library( srcs = ["tokenize.cc"], hdrs = ["tokenize.h"], deps = [ - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", ], ) @@ -205,7 +204,7 @@ cc_library( name = "test_runner", hdrs = ["test_runner.h"], deps = [ - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:string", ], ) @@ -213,9 +212,9 @@ cc_library( name = "util", hdrs = ["util.h"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string", - "//tensorflow/contrib/lite/core/api", + "//tensorflow/lite:framework", + "//tensorflow/lite:string", + "//tensorflow/lite/core/api", ], ) @@ -234,7 +233,7 @@ cc_binary( deps = [ ":parse_testdata_lib", ":tflite_driver", - "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/lite/nnapi:nnapi_lib", ], ) @@ -258,9 +257,8 @@ cc_test( name = "tf_driver_test", size = "small", srcs = ["tf_driver_test.cc"], - data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], + data = ["//tensorflow/lite:testdata/multi_add.pb"], tags = [ - "no_oss", "tflite_not_portable", ], deps = [ @@ -277,8 +275,8 @@ cc_library( ":join", ":split", ":tf_driver", - "//tensorflow/contrib/lite:string", "//tensorflow/core:framework", + "//tensorflow/lite:string", ], ) @@ -287,7 +285,6 @@ cc_test( size = "small", srcs = ["generate_testspec_test.cc"], tags = [ - "no_oss", "tflite_not_portable", ], deps = [ @@ -305,9 +302,9 @@ cc_library( "init_tensorflow.h", ], visibility = [ - "//tensorflow/contrib/lite/java/src/main/native:__subpackages__", - "//tensorflow/contrib/lite/testing:__subpackages__", - "//tensorflow/contrib/lite/tools/benchmark:__subpackages__", + "//tensorflow/lite/java/src/main/native:__subpackages__", + "//tensorflow/lite/testing:__subpackages__", + "//tensorflow/lite/tools/benchmark:__subpackages__", ], deps = select({ "//conditions:default": [ @@ -327,8 +324,8 @@ cc_library( ":generate_testspec", ":parse_testdata_lib", ":tflite_driver", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:framework", + "//tensorflow/lite:string", ], ) @@ -354,16 +351,16 @@ tf_cc_test( size = "medium", srcs = ["tflite_diff_example_test.cc"], args = [ - "--tensorflow_model=third_party/tensorflow/contrib/lite/testdata/multi_add.pb", - "--tflite_model=third_party/tensorflow/contrib/lite/testdata/multi_add.bin", + "--tensorflow_model=third_party/tensorflow/lite/testdata/multi_add.pb", + "--tflite_model=third_party/tensorflow/lite/testdata/multi_add.bin", "--input_layer=a,b,c,d", "--input_layer_type=float,float,float,float", "--input_layer_shape=1,3,4,3:1,3,4,3:1,3,4,3:1,3,4,3", "--output_layer=x,y", ], data = [ - "//tensorflow/contrib/lite:testdata/multi_add.bin", - "//tensorflow/contrib/lite:testdata/multi_add.pb", + "//tensorflow/lite:testdata/multi_add.bin", + "//tensorflow/lite:testdata/multi_add.pb", ], tags = [ "no_cuda_on_cpu_tap", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py similarity index 99% rename from tensorflow/contrib/lite/testing/generate_examples.py rename to tensorflow/lite/testing/generate_examples.py index 408b540bf11ad5..81b5ed80987771 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/lite/testing/generate_examples.py @@ -19,7 +19,7 @@ generate_examples -bazel run //tensorflow/contrib/lite/testing:generate_examples +bazel run //tensorflow/lite/testing:generate_examples To more easily debug failures use (or override) the --save_graphdefs flag to place text proto graphdefs into the generated zip files. @@ -51,7 +51,7 @@ import tensorflow as tf from google.protobuf import text_format # TODO(aselle): switch to TensorFlow's resource_loader -from tensorflow.contrib.lite.testing import generate_examples_report as report_lib +from tensorflow.lite.testing import generate_examples_report as report_lib from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.ops import rnn @@ -343,7 +343,7 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, opts = ("--input_arrays={0} --output_arrays={1}".format( ",".join(input_arrays), ",".join(output_tensors))) elif FLAGS.run_with_flex: - opts += " --allow_flex_ops --force_flex_ops" + opts += " --enable_select_tf_ops --force_select_tf_ops" cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" % (bin_path, graphdef_file.name, output_file.name, opts, stdout_file.name)) @@ -806,6 +806,11 @@ def make_binary_op_tests(zip_path, binary_operator): "input_shape_1": [[]], "input_shape_2": [[]], "activation": [False] + }, { + "dtype": [tf.float32], + "input_shape_1": [[0]], + "input_shape_2": [[1]], + "activation": [False] }] def build_graph(parameters): @@ -1134,7 +1139,7 @@ def make_gather_tests(zip_path): "params_shape": [[10], [1, 2, 20]], "indices_dtype": [tf.int32], "indices_shape": [[3], [5]], - "axis": [0, 1], + "axis": [-1, 0, 1], }] def build_graph(parameters): @@ -1147,7 +1152,8 @@ def build_graph(parameters): dtype=parameters["indices_dtype"], name="indices", shape=parameters["indices_shape"]) - out = tf.gather(params, indices, axis=parameters["axis"]) + axis = min(len(parameters["params_shape"]), parameters["axis"]) + out = tf.gather(params, indices, axis=axis) return [params, indices], [out] def build_inputs(parameters, sess, inputs, outputs): diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/lite/testing/generate_examples_report.py similarity index 100% rename from tensorflow/contrib/lite/testing/generate_examples_report.py rename to tensorflow/lite/testing/generate_examples_report.py diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/lite/testing/generate_testspec.cc similarity index 95% rename from tensorflow/contrib/lite/testing/generate_testspec.cc rename to tensorflow/lite/testing/generate_testspec.cc index 62cbeccd3315f2..74e4d2549830f4 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/lite/testing/generate_testspec.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "tensorflow/contrib/lite/testing/generate_testspec.h" -#include "tensorflow/contrib/lite/testing/join.h" -#include "tensorflow/contrib/lite/testing/split.h" -#include "tensorflow/contrib/lite/testing/tf_driver.h" +#include "tensorflow/lite/testing/generate_testspec.h" +#include "tensorflow/lite/testing/join.h" +#include "tensorflow/lite/testing/split.h" +#include "tensorflow/lite/testing/tf_driver.h" #include "tensorflow/core/framework/types.h" namespace tflite { diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/lite/testing/generate_testspec.h similarity index 91% rename from tensorflow/contrib/lite/testing/generate_testspec.h rename to tensorflow/lite/testing/generate_testspec.h index b3d0db31c01a8c..bda636f2c8081f 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.h +++ b/tensorflow/lite/testing/generate_testspec.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_GENERATE_TESTSPEC_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_GENERATE_TESTSPEC_H_ +#ifndef TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_ +#define TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_ #include #include #include -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace testing { @@ -65,4 +65,4 @@ std::vector GenerateRandomTensor(const std::vector& shape, } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_GENERATE_TESTSPEC_H_ +#endif // TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_ diff --git a/tensorflow/contrib/lite/testing/generate_testspec_test.cc b/tensorflow/lite/testing/generate_testspec_test.cc similarity index 96% rename from tensorflow/contrib/lite/testing/generate_testspec_test.cc rename to tensorflow/lite/testing/generate_testspec_test.cc index 2a97b757a41324..4450da289d2e33 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec_test.cc +++ b/tensorflow/lite/testing/generate_testspec_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/generate_testspec.h" +#include "tensorflow/lite/testing/generate_testspec.h" #include #include diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc similarity index 98% rename from tensorflow/contrib/lite/testing/generated_examples_zip_test.cc rename to tensorflow/lite/testing/generated_examples_zip_test.cc index 1ec471365e2ec6..49f7b527bb7587 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/lite/testing/generated_examples_zip_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include #include "re2/re2.h" -#include "tensorflow/contrib/lite/testing/parse_testdata.h" -#include "tensorflow/contrib/lite/testing/tflite_driver.h" -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/testing/parse_testdata.h" +#include "tensorflow/lite/testing/tflite_driver.h" +#include "tensorflow/lite/testing/util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/subprocess.h" @@ -85,9 +85,6 @@ std::map kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, - // No support for axis!=0 in GatherV2. - {R"(^\/gather.*axis=1)", "76910444"}, - // No Support for float. {R"(^\/floor_div.*dtype=tf\.float32)", "112859002"}, diff --git a/tensorflow/contrib/lite/testing/init_tensorflow.cc b/tensorflow/lite/testing/init_tensorflow.cc similarity index 94% rename from tensorflow/contrib/lite/testing/init_tensorflow.cc rename to tensorflow/lite/testing/init_tensorflow.cc index f3dcf620a20e30..ed4d12374489ed 100644 --- a/tensorflow/contrib/lite/testing/init_tensorflow.cc +++ b/tensorflow/lite/testing/init_tensorflow.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/init_tensorflow.h" +#include "tensorflow/lite/testing/init_tensorflow.h" #include #include diff --git a/tensorflow/contrib/lite/testing/init_tensorflow.h b/tensorflow/lite/testing/init_tensorflow.h similarity index 82% rename from tensorflow/contrib/lite/testing/init_tensorflow.h rename to tensorflow/lite/testing/init_tensorflow.h index 2cc89bbbcade5e..0c36a247912b71 100644 --- a/tensorflow/contrib/lite/testing/init_tensorflow.h +++ b/tensorflow/lite/testing/init_tensorflow.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_INIT_TENSORFLOW_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_INIT_TENSORFLOW_H_ +#ifndef TENSORFLOW_LITE_TESTING_INIT_TENSORFLOW_H_ +#define TENSORFLOW_LITE_TESTING_INIT_TENSORFLOW_H_ namespace tflite { @@ -23,4 +23,4 @@ void InitTensorFlow(); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_INIT_TENSORFLOW_H_ +#endif // TENSORFLOW_LITE_TESTING_INIT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/lite/testing/join.h similarity index 89% rename from tensorflow/contrib/lite/testing/join.h rename to tensorflow/lite/testing/join.h index 4be19ad7569c33..7d0040c488a4ce 100644 --- a/tensorflow/contrib/lite/testing/join.h +++ b/tensorflow/lite/testing/join.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ +#ifndef TENSORFLOW_LITE_TESTING_JOIN_H_ +#define TENSORFLOW_LITE_TESTING_JOIN_H_ #include #include -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace testing { @@ -56,4 +56,4 @@ inline string Join(uint8_t* data, size_t len, } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ +#endif // TENSORFLOW_LITE_TESTING_JOIN_H_ diff --git a/tensorflow/contrib/lite/testing/join_test.cc b/tensorflow/lite/testing/join_test.cc similarity index 96% rename from tensorflow/contrib/lite/testing/join_test.cc rename to tensorflow/lite/testing/join_test.cc index bd04528381f6d3..a8d036c547ded3 100644 --- a/tensorflow/contrib/lite/testing/join_test.cc +++ b/tensorflow/lite/testing/join_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/join.h" +#include "tensorflow/lite/testing/join.h" #include #include diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/lite/testing/message.cc similarity index 96% rename from tensorflow/contrib/lite/testing/message.cc rename to tensorflow/lite/testing/message.cc index 03fae4bb86a30e..08aac6f6aa192c 100644 --- a/tensorflow/contrib/lite/testing/message.cc +++ b/tensorflow/lite/testing/message.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/message.h" +#include "tensorflow/lite/testing/message.h" #include -#include "tensorflow/contrib/lite/testing/tokenize.h" +#include "tensorflow/lite/testing/tokenize.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/lite/testing/message.h similarity index 94% rename from tensorflow/contrib/lite/testing/message.h rename to tensorflow/lite/testing/message.h index e2bc4082141f06..e6566ab11ca7dd 100644 --- a/tensorflow/contrib/lite/testing/message.h +++ b/tensorflow/lite/testing/message.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#ifndef TENSORFLOW_LITE_TESTING_MESSAGE_H_ +#define TENSORFLOW_LITE_TESTING_MESSAGE_H_ #include #include @@ -79,4 +79,4 @@ class Message { } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#endif // TENSORFLOW_LITE_TESTING_MESSAGE_H_ diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/lite/testing/message_test.cc similarity index 98% rename from tensorflow/contrib/lite/testing/message_test.cc rename to tensorflow/lite/testing/message_test.cc index fb6a49bd6f1ea8..bec4915e5853d6 100644 --- a/tensorflow/contrib/lite/testing/message_test.cc +++ b/tensorflow/lite/testing/message_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/message.h" +#include "tensorflow/lite/testing/message.h" #include diff --git a/tensorflow/contrib/lite/testing/model_coverage/BUILD b/tensorflow/lite/testing/model_coverage/BUILD similarity index 85% rename from tensorflow/contrib/lite/testing/model_coverage/BUILD rename to tensorflow/lite/testing/model_coverage/BUILD index c8359bab064b7c..7e6a65997d38d1 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/BUILD +++ b/tensorflow/lite/testing/model_coverage/BUILD @@ -1,5 +1,5 @@ package(default_visibility = [ - "//tensorflow/contrib/lite:__subpackages__", + "//tensorflow/lite:__subpackages__", ]) licenses(["notice"]) # Apache 2.0 @@ -10,7 +10,7 @@ py_binary( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - "//tensorflow/contrib/lite/python:lite", + "//tensorflow/lite/python:lite", "//tensorflow/python:platform", ], ) diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py similarity index 98% rename from tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py rename to tensorflow/lite/testing/model_coverage/model_coverage_lib.py index 2dc5aeb3023ce0..ce8ef0b1960021 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -20,8 +20,8 @@ import numpy as np -from tensorflow.contrib.lite.python import convert_saved_model as _convert_saved_model -from tensorflow.contrib.lite.python import lite as _lite +from tensorflow.lite.python import convert_saved_model as _convert_saved_model +from tensorflow.lite.python import lite as _lite from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py similarity index 97% rename from tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py rename to tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py index 98dbff4d7911bc..6b4e7427ed9c69 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py @@ -22,8 +22,8 @@ import tempfile import numpy as np -from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage +from tensorflow.lite.python import lite +from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.framework import constant_op diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/lite/testing/nnapi_example.cc similarity index 91% rename from tensorflow/contrib/lite/testing/nnapi_example.cc rename to tensorflow/lite/testing/nnapi_example.cc index 5870782b69217f..22df8dbd882143 100644 --- a/tensorflow/contrib/lite/testing/nnapi_example.cc +++ b/tensorflow/lite/testing/nnapi_example.cc @@ -17,7 +17,7 @@ limitations under the License. // the future. // // Usage: bazel run -c opt \ -// tensorflow/contrib/lite/nnapi:nnapi_example -- +// tensorflow/lite/nnapi:nnapi_example -- // #include #include @@ -25,9 +25,9 @@ limitations under the License. #include #include #include -#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" -#include "tensorflow/contrib/lite/testing/parse_testdata.h" -#include "tensorflow/contrib/lite/testing/tflite_driver.h" +#include "tensorflow/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/lite/testing/parse_testdata.h" +#include "tensorflow/lite/testing/tflite_driver.h" string dirname(const string& s) { return s.substr(0, s.find_last_of("/")); } diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/lite/testing/parse_testdata.cc similarity index 98% rename from tensorflow/contrib/lite/testing/parse_testdata.cc rename to tensorflow/lite/testing/parse_testdata.cc index 389688d552051e..bb540087942b7e 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.cc +++ b/tensorflow/lite/testing/parse_testdata.cc @@ -16,7 +16,7 @@ limitations under the License. // Format is ASCII // TODO(aselle): Switch to protobuf, but the android team requested a simple // ASCII file. -#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/lite/testing/parse_testdata.h" #include #include @@ -26,9 +26,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/testing/message.h" -#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/testing/message.h" +#include "tensorflow/lite/testing/split.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/lite/testing/parse_testdata.h similarity index 89% rename from tensorflow/contrib/lite/testing/parse_testdata.h rename to tensorflow/lite/testing/parse_testdata.h index 26ee8258662e68..0f3dc32afca974 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.h +++ b/tensorflow/lite/testing/parse_testdata.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_ +#ifndef TENSORFLOW_LITE_TESTING_PARSE_TESTDATA_H_ +#define TENSORFLOW_LITE_TESTING_PARSE_TESTDATA_H_ #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/testing/test_runner.h" namespace tflite { namespace testing { @@ -72,4 +72,4 @@ bool ParseAndRunTests(std::istream* input, TestRunner* test_runner, } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_ +#endif // TENSORFLOW_LITE_TESTING_PARSE_TESTDATA_H_ diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/lite/testing/split.cc similarity index 96% rename from tensorflow/contrib/lite/testing/split.cc rename to tensorflow/lite/testing/split.cc index 5836f4ff049b70..594b0d3f8a2af8 100644 --- a/tensorflow/contrib/lite/testing/split.cc +++ b/tensorflow/lite/testing/split.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/lite/testing/split.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/lite/testing/split.h similarity index 93% rename from tensorflow/contrib/lite/testing/split.h rename to tensorflow/lite/testing/split.h index 896f2949efa6ae..c33738997cae58 100644 --- a/tensorflow/contrib/lite/testing/split.h +++ b/tensorflow/lite/testing/split.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#ifndef TENSORFLOW_LITE_TESTING_SPLIT_H_ +#define TENSORFLOW_LITE_TESTING_SPLIT_H_ #include #include #include #include -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace testing { @@ -93,4 +93,4 @@ inline std::vector Split(const string& s, const string& delimiter) { } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#endif // TENSORFLOW_LITE_TESTING_SPLIT_H_ diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/lite/testing/split_test.cc similarity index 97% rename from tensorflow/contrib/lite/testing/split_test.cc rename to tensorflow/lite/testing/split_test.cc index 76b918cbcd83ef..77de485d5710ea 100644 --- a/tensorflow/contrib/lite/testing/split_test.cc +++ b/tensorflow/lite/testing/split_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/lite/testing/split.h" #include #include diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/lite/testing/test_runner.h similarity index 95% rename from tensorflow/contrib/lite/testing/test_runner.h rename to tensorflow/lite/testing/test_runner.h index fac7d01aab4b1e..303155b072bc3a 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/lite/testing/test_runner.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#ifndef TENSORFLOW_LITE_TESTING_TEST_RUNNER_H_ +#define TENSORFLOW_LITE_TESTING_TEST_RUNNER_H_ #include #include #include #include -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace testing { @@ -127,4 +127,4 @@ class TestRunner { } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#endif // TENSORFLOW_LITE_TESTING_TEST_RUNNER_H_ diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/lite/testing/test_runner_test.cc similarity index 97% rename from tensorflow/contrib/lite/testing/test_runner_test.cc rename to tensorflow/lite/testing/test_runner_test.cc index 3f04aa20bd7de8..39ec81582bcd8f 100644 --- a/tensorflow/contrib/lite/testing/test_runner_test.cc +++ b/tensorflow/lite/testing/test_runner_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/lite/testing/test_runner.h" #include #include diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/lite/testing/tf_driver.cc similarity index 97% rename from tensorflow/contrib/lite/testing/tf_driver.cc rename to tensorflow/lite/testing/tf_driver.cc index 30381ba028352e..36c556ba049509 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/lite/testing/tf_driver.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/tf_driver.h" +#include "tensorflow/lite/testing/tf_driver.h" #include #include -#include "tensorflow/contrib/lite/testing/join.h" -#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/lite/testing/join.h" +#include "tensorflow/lite/testing/split.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tflite { diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/lite/testing/tf_driver.h similarity index 90% rename from tensorflow/contrib/lite/testing/tf_driver.h rename to tensorflow/lite/testing/tf_driver.h index b766f85c4ddee9..f10689cb58c175 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.h +++ b/tensorflow/lite/testing/tf_driver.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ +#ifndef TENSORFLOW_LITE_TESTING_TF_DRIVER_H_ +#define TENSORFLOW_LITE_TESTING_TF_DRIVER_H_ #include #include -#include "tensorflow/contrib/lite/testing/split.h" -#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/lite/testing/split.h" +#include "tensorflow/lite/testing/test_runner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -72,4 +72,4 @@ class TfDriver : public TestRunner { } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ +#endif // TENSORFLOW_LITE_TESTING_TF_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tf_driver_test.cc b/tensorflow/lite/testing/tf_driver_test.cc similarity index 93% rename from tensorflow/contrib/lite/testing/tf_driver_test.cc rename to tensorflow/lite/testing/tf_driver_test.cc index c0faa4676adc3e..d178ccf1e3f7d8 100644 --- a/tensorflow/contrib/lite/testing/tf_driver_test.cc +++ b/tensorflow/lite/testing/tf_driver_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/tf_driver.h" +#include "tensorflow/lite/testing/tf_driver.h" #include #include @@ -29,7 +29,7 @@ TEST(TfDriverTest, SimpleTest) { {"1,8,8,3", "1,8,8,3", "1,8,8,3", "1,8,8,3"}, {"x", "y"})); runner->LoadModel( - "third_party/tensorflow/contrib/lite/testdata/multi_add.pb"); + "third_party/tensorflow/lite/testdata/multi_add.pb"); EXPECT_TRUE(runner->IsValid()) << runner->GetErrorMessage(); ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/lite/testing/tflite_diff_example_test.cc similarity index 90% rename from tensorflow/contrib/lite/testing/tflite_diff_example_test.cc rename to tensorflow/lite/testing/tflite_diff_example_test.cc index 49696ac76be9c2..cb61cd4e942140 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc +++ b/tensorflow/lite/testing/tflite_diff_example_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/init_tensorflow.h" -#include "tensorflow/contrib/lite/testing/tflite_diff_flags.h" -#include "tensorflow/contrib/lite/testing/tflite_diff_util.h" +#include "tensorflow/lite/testing/init_tensorflow.h" +#include "tensorflow/lite/testing/tflite_diff_flags.h" +#include "tensorflow/lite/testing/tflite_diff_util.h" int main(int argc, char** argv) { ::tflite::InitTensorFlow(); // For Flex support. diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/lite/testing/tflite_diff_flags.h similarity index 92% rename from tensorflow/contrib/lite/testing/tflite_diff_flags.h rename to tensorflow/lite/testing/tflite_diff_flags.h index ad889a2f198644..2fe068eb20f1fb 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/lite/testing/tflite_diff_flags.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#ifndef TENSORFLOW_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#define TENSORFLOW_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ #include -#include "tensorflow/contrib/lite/testing/split.h" -#include "tensorflow/contrib/lite/testing/tflite_diff_util.h" +#include "tensorflow/lite/testing/split.h" +#include "tensorflow/lite/testing/tflite_diff_util.h" #include "tensorflow/core/util/command_line_flags.h" namespace tflite { @@ -88,4 +88,4 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#endif // TENSORFLOW_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/lite/testing/tflite_diff_util.cc similarity index 85% rename from tensorflow/contrib/lite/testing/tflite_diff_util.cc rename to tensorflow/lite/testing/tflite_diff_util.cc index c6ca796ac25d2c..0142ae4217eaea 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc +++ b/tensorflow/lite/testing/tflite_diff_util.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/testing/generate_testspec.h" -#include "tensorflow/contrib/lite/testing/parse_testdata.h" -#include "tensorflow/contrib/lite/testing/tflite_diff_util.h" -#include "tensorflow/contrib/lite/testing/tflite_driver.h" +#include "tensorflow/lite/testing/generate_testspec.h" +#include "tensorflow/lite/testing/parse_testdata.h" +#include "tensorflow/lite/testing/tflite_diff_util.h" +#include "tensorflow/lite/testing/tflite_driver.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/lite/testing/tflite_diff_util.h similarity index 89% rename from tensorflow/contrib/lite/testing/tflite_diff_util.h rename to tensorflow/lite/testing/tflite_diff_util.h index 28b14bd143ab0e..3f9f10892db287 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.h +++ b/tensorflow/lite/testing/tflite_diff_util.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_UTIL_H_ +#ifndef TENSORFLOW_LITE_TESTING_TFLITE_DIFF_UTIL_H_ +#define TENSORFLOW_LITE_TESTING_TFLITE_DIFF_UTIL_H_ #include -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace testing { @@ -55,4 +55,4 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations); } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_UTIL_H_ +#endif // TENSORFLOW_LITE_TESTING_TFLITE_DIFF_UTIL_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc similarity index 95% rename from tensorflow/contrib/lite/testing/tflite_driver.cc rename to tensorflow/lite/testing/tflite_driver.cc index ef49e6f8bc30a6..8619f5f83662bf 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/tflite_driver.h" +#include "tensorflow/lite/testing/tflite_driver.h" #include -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" -#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/delegates/flex/delegate.h" +#include "tensorflow/lite/testing/split.h" namespace tflite { namespace testing { @@ -173,9 +173,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { interpreter_->UseNNAPI(use_nnapi_); if (delegate_) { - if (interpreter_->ModifyGraphWithDelegate(delegate_.get(), - /*allow_dynamic_tensors=*/true) != - kTfLiteOk) { + if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) { Invalidate("Unable to the build graph using the delegate"); return; } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h similarity index 81% rename from tensorflow/contrib/lite/testing/tflite_driver.h rename to tensorflow/lite/testing/tflite_driver.h index dc2a4e58773a9e..785baf0f004f33 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/lite/testing/tflite_driver.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#ifndef TENSORFLOW_LITE_TESTING_TFLITE_DRIVER_H_ +#define TENSORFLOW_LITE_TESTING_TFLITE_DRIVER_H_ #include -#include "tensorflow/contrib/lite/delegates/flex/delegate.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/lite/delegates/flex/delegate.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/testing/test_runner.h" namespace tflite { namespace testing { @@ -64,4 +64,4 @@ class TfLiteDriver : public TestRunner { } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#endif // TENSORFLOW_LITE_TESTING_TFLITE_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/lite/testing/tflite_driver_test.cc similarity index 93% rename from tensorflow/contrib/lite/testing/tflite_driver_test.cc rename to tensorflow/lite/testing/tflite_driver_test.cc index 37010c468f250f..6e953e5e19b8f6 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver_test.cc +++ b/tensorflow/lite/testing/tflite_driver_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/tflite_driver.h" +#include "tensorflow/lite/testing/tflite_driver.h" #include #include @@ -26,7 +26,7 @@ using ::testing::ElementsAre; TEST(TfliteDriverTest, SimpleTest) { std::unique_ptr runner(new TfLiteDriver(/*use_nnapi=*/false)); - runner->SetModelBaseDir("tensorflow/contrib/lite"); + runner->SetModelBaseDir("tensorflow/lite"); runner->LoadModel("testdata/multi_add.bin"); ASSERT_TRUE(runner->IsValid()); diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/lite/testing/tokenize.cc similarity index 96% rename from tensorflow/contrib/lite/testing/tokenize.cc rename to tensorflow/lite/testing/tokenize.cc index 2e84ea475cae60..bb4753580131ad 100644 --- a/tensorflow/contrib/lite/testing/tokenize.cc +++ b/tensorflow/lite/testing/tokenize.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/tokenize.h" +#include "tensorflow/lite/testing/tokenize.h" #include #include -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/string.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/lite/testing/tokenize.h similarity index 89% rename from tensorflow/contrib/lite/testing/tokenize.h rename to tensorflow/lite/testing/tokenize.h index 819539185168df..7bd2783337a763 100644 --- a/tensorflow/contrib/lite/testing/tokenize.h +++ b/tensorflow/lite/testing/tokenize.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZE_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZE_H_ +#ifndef TENSORFLOW_LITE_TESTING_TOKENIZE_H_ +#define TENSORFLOW_LITE_TESTING_TOKENIZE_H_ #include #include @@ -39,4 +39,4 @@ void Tokenize(std::istream* input, TokenProcessor* processor); } // namespace testing } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZE_H_ +#endif // TENSORFLOW_LITE_TESTING_TOKENIZE_H_ diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/lite/testing/tokenize_test.cc similarity index 98% rename from tensorflow/contrib/lite/testing/tokenize_test.cc rename to tensorflow/lite/testing/tokenize_test.cc index 80f44aacca7e90..302ae589d02c3e 100644 --- a/tensorflow/contrib/lite/testing/tokenize_test.cc +++ b/tensorflow/lite/testing/tokenize_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/tokenize.h" +#include "tensorflow/lite/testing/tokenize.h" #include #include diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/lite/testing/util.h similarity index 85% rename from tensorflow/contrib/lite/testing/util.h rename to tensorflow/lite/testing/util.h index 925791d3908dc5..45751497de47bc 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/lite/testing/util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#ifndef TENSORFLOW_LITE_TESTING_UTIL_H_ +#define TENSORFLOW_LITE_TESTING_UTIL_H_ #include -#include "tensorflow/contrib/lite/core/api/error_reporter.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/string.h" namespace tflite { @@ -56,4 +56,4 @@ inline void LogToStderr() { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#endif // TENSORFLOW_LITE_TESTING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/lite/toco/BUILD similarity index 96% rename from tensorflow/contrib/lite/toco/BUILD rename to tensorflow/lite/toco/BUILD index 96b88b60fc6509..14302874441c4a 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -74,8 +74,8 @@ cc_library( linkstatic = 1, visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/lite/kernels/internal:reference_base", - "//tensorflow/contrib/lite/kernels/internal:types", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:types", ], ) @@ -281,9 +281,9 @@ cc_library( ":runtime", ":toco_port", ":tooling_util", - "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:strided_slice_logic", "//tensorflow/core:lib", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:strided_slice_logic", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -325,13 +325,13 @@ cc_library( "@protobuf_archive//:protobuf_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:resolve_cluster", - "//tensorflow/contrib/lite/toco/tflite:export", - "//tensorflow/contrib/lite/toco/tflite:import", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco/tensorflow_graph_matching:resolve_cluster", + "//tensorflow/lite/toco/tflite:export", + "//tensorflow/lite/toco/tflite:import", ] + select({ # Placeholder for internal darwin rule. "//conditions:default": [], @@ -373,8 +373,8 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", - "//tensorflow/contrib/lite/kernels/internal:types", "//tensorflow/core:lib", + "//tensorflow/lite/kernels/internal:types", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", "@protobuf_archive//:protobuf_headers", diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/lite/toco/README.md similarity index 100% rename from tensorflow/contrib/lite/toco/README.md rename to tensorflow/lite/toco/README.md diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/lite/toco/allocate_transient_arrays.cc similarity index 98% rename from tensorflow/contrib/lite/toco/allocate_transient_arrays.cc rename to tensorflow/lite/toco/allocate_transient_arrays.cc index 18c904c6d4e8ad..3ec53c9c2d63ee 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/lite/toco/allocate_transient_arrays.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/allocate_transient_arrays.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/lite/toco/allocate_transient_arrays.h similarity index 87% rename from tensorflow/contrib/lite/toco/allocate_transient_arrays.h rename to tensorflow/lite/toco/allocate_transient_arrays.h index 59d8ada1e9bb98..5d43d4cc3fa802 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h +++ b/tensorflow/lite/toco/allocate_transient_arrays.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#ifndef TENSORFLOW_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#define TENSORFLOW_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -41,4 +41,4 @@ void AllocateTransientArrays(Model* model, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#endif // TENSORFLOW_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/lite/toco/args.h similarity index 96% rename from tensorflow/contrib/lite/toco/args.h rename to tensorflow/lite/toco/args.h index 2699ac76e1d2c3..188f2f7e7af61c 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/lite/toco/args.h @@ -15,20 +15,20 @@ limitations under the License. // This abstracts command line arguments in toco. // Arg is a parseable type that can register a default value, be able to // parse itself, and keep track of whether it was specified. -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ +#ifndef TENSORFLOW_LITE_TOCO_ARGS_H_ +#define TENSORFLOW_LITE_TOCO_ARGS_H_ #include #include #include -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_port.h" #if defined(PLATFORM_GOOGLE) #include "strings/split.h" #include "strings/strip.h" #endif #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/toco_types.h" namespace toco { @@ -248,10 +248,10 @@ struct ParsedTocoFlags { Arg dedupe_array_min_size_bytes = Arg(64); Arg split_tflite_lstm_inputs = Arg(true); // WARNING: Experimental interface, subject to change - Arg allow_flex_ops = Arg(false); + Arg enable_select_tf_ops = Arg(false); // WARNING: Experimental interface, subject to change - Arg force_flex_ops = Arg(false); + Arg force_select_tf_ops = Arg(false); }; } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ +#endif // TENSORFLOW_LITE_TOCO_ARGS_H_ diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/lite/toco/dump_graphviz.cc similarity index 97% rename from tensorflow/contrib/lite/toco/dump_graphviz.cc rename to tensorflow/lite/toco/dump_graphviz.cc index 30525efd2391bb..8896893f3579ab 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/lite/toco/dump_graphviz.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/dump_graphviz.h" +#include "tensorflow/lite/toco/dump_graphviz.h" #include #include @@ -20,11 +20,11 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/strip.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" using toco::port::AppendF; diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/lite/toco/dump_graphviz.h similarity index 78% rename from tensorflow/contrib/lite/toco/dump_graphviz.h rename to tensorflow/lite/toco/dump_graphviz.h index ea5a4031c39580..9697bd6f0dc434 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.h +++ b/tensorflow/lite/toco/dump_graphviz.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#ifndef TENSORFLOW_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#define TENSORFLOW_LITE_TOCO_DUMP_GRAPHVIZ_H_ #include -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -25,4 +25,4 @@ void DumpGraphviz(const Model& model, string* output_file_contents); } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#endif // TENSORFLOW_LITE_TOCO_DUMP_GRAPHVIZ_H_ diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc similarity index 99% rename from tensorflow/contrib/lite/toco/export_tensorflow.cc rename to tensorflow/lite/toco/export_tensorflow.cc index 41a82b57208c40..1752745aaee987 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/lite/toco/export_tensorflow.cc @@ -22,11 +22,11 @@ limitations under the License. #include "google/protobuf/text_format.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tensorflow_util.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/lite/toco/export_tensorflow.h similarity index 79% rename from tensorflow/contrib/lite/toco/export_tensorflow.h rename to tensorflow/lite/toco/export_tensorflow.h index d7310bb75f258c..09c966ded621d4 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.h +++ b/tensorflow/lite/toco/export_tensorflow.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#ifndef TENSORFLOW_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#define TENSORFLOW_LITE_TOCO_EXPORT_TENSORFLOW_H_ #include -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -26,4 +26,4 @@ void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model); } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#endif // TENSORFLOW_LITE_TOCO_EXPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/lite/toco/format_port.h similarity index 92% rename from tensorflow/contrib/lite/toco/format_port.h rename to tensorflow/lite/toco/format_port.h index 44e66845715237..69833d965c57d3 100644 --- a/tensorflow/contrib/lite/toco/format_port.h +++ b/tensorflow/lite/toco/format_port.h @@ -16,10 +16,10 @@ limitations under the License. // and absl::StrAppendFormat. Unfortunately, type safety is not as good as a // a full C++ example. // TODO(aselle): When absl adds support for StrFormat, use that instead. -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ +#ifndef TENSORFLOW_LITE_TOCO_FORMAT_PORT_H_ +#define TENSORFLOW_LITE_TOCO_FORMAT_PORT_H_ -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace toco { @@ -74,4 +74,4 @@ inline string StringF(const char* fmt, Args&&... args) { } // namespace port } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ +#endif // TENSORFLOW_LITE_TOCO_FORMAT_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/g3doc/README.md b/tensorflow/lite/toco/g3doc/README.md similarity index 100% rename from tensorflow/contrib/lite/toco/g3doc/README.md rename to tensorflow/lite/toco/g3doc/README.md diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc rename to tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 8a945ac4350f21..e3b0de55557291 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc rename to tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index a1510128910d74..a707a906a815cd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc rename to tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc index 4a264e1cf1d4fc..b4cd4635982fd4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc rename to tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc index a0bd1ed4a4d8a1..52aaefb3d74e4e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc rename to tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc index d7cacf77f48b6b..130fe58a9d13ff 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc rename to tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc index 78779243a9e15d..27c503f5ddd14d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc rename to tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc index b6d712ca44c3e8..fb416cabededf5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc rename to tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index e5a96d43351761..ae97cef520e0f2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc rename to tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc index ebc0e9afca22d0..8e93bc237897b6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/lite/toco/graph_transformations/dequantize.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc rename to tensorflow/lite/toco/graph_transformations/dequantize.cc index 2119174950b1ad..cc5dddbb40e732 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/lite/toco/graph_transformations/dequantize.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc similarity index 86% rename from tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc rename to tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc index 1555cf60a1cdea..bb8679bced8077 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc similarity index 88% rename from tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc rename to tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc index 7d66ea5dd23451..c3c95afd967dc3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc +++ b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc rename to tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc index 72b1dda3be584b..62a4b52bbb877b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc rename to tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index 60dcd5268442fe..918bb489995cd3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc rename to tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc index 88511a7d3c4258..f467a95f348663 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc rename to tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index 0de22b8ff4276c..6b4765b23c47d0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc rename to tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index b8da756d857355..a19e51fa943755 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -281,6 +281,18 @@ ::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model, const auto& bias_name = preceding_op->inputs[2]; const auto& weights = model->GetArray(weights_name); const auto& bias = model->GetArray(bias_name); + + if (weights.data_type != ArrayDataType::kFloat || + bias.data_type != ArrayDataType::kFloat) { + AddMessageF( + "Not fusing %s into preceding %s because one of weights or bias array " + "is not float (types are %s and %s)", + LogName(*binary_op), LogName(*preceding_op), + ArrayDataTypeName(weights.data_type), + ArrayDataTypeName(bias.data_type)); + return ::tensorflow::Status::OK(); + } + const int count_ops_consuming_bias = CountOpsWithInput(*model, bias_name); const int count_ops_consuming_weights = CountOpsWithInput(*model, weights_name); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc rename to tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc index 4848867b9a0a73..ba3e277f676ce8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc rename to tensorflow/lite/toco/graph_transformations/graph_transformations.cc index 8b0bc2d865ea49..a0260e24013bfd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include #include @@ -21,8 +21,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h rename to tensorflow/lite/toco/graph_transformations/graph_transformations.h index a89db320ea9d84..73a90c8239b2a2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ #include #include #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/toco_port.h" namespace toco { @@ -287,4 +287,4 @@ class IdentifyDilatedConv : public GraphTransformation { } // end namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#endif // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc rename to tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc index a4f8d64f4dd2a0..df50f31de88cd8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc rename to tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc index 9e4a3005a1d534..e27f975348b7ec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -15,9 +15,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc rename to tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc index 78f60f52fbdbc7..dabd4bd209f450 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc rename to tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc index 13664bb344def9..6e0a7cdc31af2b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc rename to tensorflow/lite/toco/graph_transformations/identify_lstm.cc index 7fd8f906e2c270..089ecee959a3ab 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc rename to tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 6ccce923f361d7..2fae01a6987277 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -18,10 +18,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc rename to tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index ad5120e2aa5f4e..c519e654636a55 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -18,10 +18,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc rename to tensorflow/lite/toco/graph_transformations/identify_prelu.cc index c11fee4dc94041..1205ddc7304f39 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" // This transformation rule tries to identify the PRelu structure generated by diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc rename to tensorflow/lite/toco/graph_transformations/identify_relu1.cc index 51d0629362edbe..bcd5b0ca04a8a2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc rename to tensorflow/lite/toco/graph_transformations/lstm_utils.cc index 910a9605897988..3414a7fd7fe2d1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc +++ b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/lite/toco/graph_transformations/lstm_utils.h similarity index 92% rename from tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h rename to tensorflow/lite/toco/graph_transformations/lstm_utils.h index 6d8603a1133a74..949292ee84b292 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h +++ b/tensorflow/lite/toco/graph_transformations/lstm_utils.h @@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ +#ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ +#define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ #include #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { // For consistency with the parameters defined in extended LstmCell's kernel -// (tensorflow/contrib/lite/kernels/lstm.cc), +// (tensorflow/lite/kernels/lstm.cc), // use lowercase for these constants. enum ExtendedLstmCellInputs { @@ -108,4 +108,4 @@ bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ +#endif // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc rename to tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index 5bf17d5b4cd1a0..b914838b91c965 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc rename to tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 06de9b1cd89571..80170fe8bcb73e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc rename to tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc index f0d8d924adbd34..0f3c4d34d66ea2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -14,9 +14,9 @@ ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc similarity index 92% rename from tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc rename to tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc index 9c1ed2b732dcce..95de60262e754d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc rename to tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index 47faa20a291a0d..9a458dccb9cbb3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc similarity index 92% rename from tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc rename to tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc index 3cf191436dc8d5..d31ba956afd986 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc rename to tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index d0113237ce6e43..04a5a1c1687b4c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc similarity index 99% rename from tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc rename to tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 514a596f1e4c20..78ea54e452b9dd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include "absl/strings/str_join.h" -#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/lite/toco/graph_transformations/quantization_util.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc rename to tensorflow/lite/toco/graph_transformations/quantization_util.cc index 82146c5a66127a..56f83c9793f723 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc +++ b/tensorflow/lite/toco/graph_transformations/quantization_util.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/lite/toco/graph_transformations/quantization_util.h similarity index 85% rename from tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h rename to tensorflow/lite/toco/graph_transformations/quantization_util.h index cf093c6f17b458..d226aeab8b788c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h +++ b/tensorflow/lite/toco/graph_transformations/quantization_util.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ +#ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ +#define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ -#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -60,4 +60,4 @@ bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ +#endif // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/quantize.cc rename to tensorflow/lite/toco/graph_transformations/quantize.cc index 0a89deadcc2fca..e28b7288f0102a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc rename to tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc index 0c32218ff2e972..4d621018dc3fc5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc similarity index 90% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc rename to tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc index fe8023ab8fe1d8..ed551d0122348e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc rename to tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc index be8c0acc7b5cc6..647146b407116a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc similarity index 83% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc rename to tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 37fe5fa3d7190c..e0f7bc9a053b5d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc index 68c6fb65c5c6b8..8879a7cd2664ed 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc similarity index 83% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc index faaa2a828e306c..bfa9314a6964f4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc index ccfc181fe00745..565ccb663a8008 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_fake_quant.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_fake_quant.cc index 5448a816bc43ac..2891e41f3072c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_fake_quant.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc index d5983a1f12ffbc..5239d550762fe3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h similarity index 86% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h rename to tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h index 663704e5acf745..315edc0121afc6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -55,4 +55,4 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#endif // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc similarity index 89% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc index 4133815285fdc5..56acf22f7f1de2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc similarity index 87% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index 0f0ae4af693728..f1037994c97616 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc similarity index 92% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc index 1caf9448797984..7dea3c79c57a48 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc b/tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc similarity index 88% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc rename to tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc index dcb0148d583f1c..330e16b3afdc88 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_trivial_slice.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/lite/toco/graph_transformations/remove_unused_op.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc rename to tensorflow/lite/toco/graph_transformations/remove_unused_op.cc index 3cd5d06baebc5a..ac05afb81947a4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_unused_op.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc rename to tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 3c8d41108918ad..6a4b9198548956 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc rename to tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc index a2c06e71e8ec93..fdd411c84c2678 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc rename to tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc index a79779f55d9c13..e972e5c9014865 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index d039d7d690d715..7aa92de4f6f878 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc index 586f546a30da32..0e1671c61c6b89 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc index a26aa21def090d..98ff4ab02ea621 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 4f330fdd840153..d52f7d49169c7e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc index ef234563fdc518..c9021019bf4167 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_fill.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc index 26616374d8a8b4..1149930131e0d6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc index 55fa0e22eb31ee..168f79bebdaaaa 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_pack.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc index db0fbba52826eb..a8afbb7de54204 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_random_uniform.cc @@ -15,9 +15,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/lib/random/philox_random.h" diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc index 198e95ab8833c0..4cb27d97ec1c92 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_range.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc similarity index 95% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc index ef4a896e32f0f4..9e21fa564e89d4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc index ab1e0bd7a076ac..82b2f4ab8782d6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_select.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index a1756a820734d6..00ab85882796b8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc index 6b6465c2feb883..503807f2318c74 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 83cb8bec0d6146..0c9effee1fd364 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc index 685fec96191e35..75631304968e21 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_tile.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc index 612d03b97ed39f..9514848682f54d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_transpose.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc rename to tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc index 3034c1b1eb0fcf..43070b063c4a42 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc rename to tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc index eed971c1d50293..c0becaf7d39cdb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc index 69209b8dec7dd0..ffad0d0d315128 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc rename to tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc index e5b74e2bb155a5..51c724dd1ab058 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc index adc87753bc71cc..25b823f8483935 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc index 1f0f17a37a99c6..bcc9f5363ac080 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_padv2_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc similarity index 93% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc index c3246ab90fc492..ea5d33009b4b60 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc index ee5c4810e61326..f70e80b8e7702b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc similarity index 90% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc index 7b7a59264ff74f..24a3482a6fe9dd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc index 5a838168de7382..1f86b35c34cab2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index 3804145c4f8cef..dd1e6fccd72aa9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc similarity index 86% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index c601b0774e6274..3f2ae471a2f10a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index f54f5b42a1f4c0..a62e082e836797 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc rename to tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index 4927ccd95d34f3..ce185847cd0dde 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc rename to tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index da039da546fc5d..637ffda533ae1a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc similarity index 92% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc rename to tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 9beea3e937b284..9ee4e6ec6b7a73 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc rename to tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index e215981b42262f..f26efacaaeec75 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc similarity index 91% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc rename to tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc index aa7945391c766c..71c0a884da1f73 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc +++ b/tensorflow/lite/toco/graph_transformations/resolve_transpose_attributes.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc b/tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc rename to tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc index e9f24a29ab4695..195ea70e34bf87 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc +++ b/tensorflow/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/lite/toco/graph_transformations/tests/BUILD similarity index 60% rename from tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD rename to tensorflow/lite/toco/graph_transformations/tests/BUILD index 6f1be298caaf11..2e9b213d0018f5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/lite/toco/graph_transformations/tests/BUILD @@ -12,9 +12,9 @@ tf_cc_test( srcs = ["lstm_utils_test.cc"], tags = ["no_oss"], deps = [ - "//tensorflow/contrib/lite/toco:graph_transformations", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/lite/toco:graph_transformations", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", "@com_google_googletest//:gtest_main", ], ) @@ -24,9 +24,9 @@ tf_cc_test( srcs = ["resolve_constant_concatenation_test.cc"], tags = ["no_oss"], deps = [ - "//tensorflow/contrib/lite/toco:graph_transformations", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/lite/toco:graph_transformations", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", "@com_google_googletest//:gtest_main", ], ) @@ -36,9 +36,9 @@ tf_cc_test( srcs = ["resolve_constant_unary_test.cc"], tags = ["no_oss"], deps = [ - "//tensorflow/contrib/lite/toco:graph_transformations", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/lite/toco:graph_transformations", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc similarity index 99% rename from tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc rename to tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc index 6aae0775d3445d..bdb27e8af2e359 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc rename to tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index e2a6f12481c336..00d60b79ca96f6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc similarity index 94% rename from tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc rename to tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc index 57d85a0435179f..246a13a0610216 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc similarity index 92% rename from tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc rename to tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc index 4ada5c3fd07260..3e36dd5a45c720 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc similarity index 98% rename from tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc rename to tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index e19527968d67f9..e57f175812f4c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc similarity index 97% rename from tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc rename to tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc index 5ff39aa313b279..d59954fc740ed9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc similarity index 95% rename from tensorflow/contrib/lite/toco/import_tensorflow.cc rename to tensorflow/lite/toco/import_tensorflow.cc index c0d943039f39ca..76c6985e3a2e09 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/import_tensorflow.h" +#include "tensorflow/lite/toco/import_tensorflow.h" #include #include @@ -27,11 +27,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" -#include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/lite/toco/tensorflow_util.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -1122,28 +1122,137 @@ tensorflow::Status ConvertConcatOperator( return tensorflow::Status::OK(); } +static constexpr int kAnyNumInputs = -1; + +enum FlexSupport { kFlexOk, kFlexNotOk }; + // This method supports simple operators without additional attributes. -template -tensorflow::Status ConvertSimpleOperator( +// Converts a simple operator that takes no attributes. The list of inputs is +// taken from the given NodeDef, and its number must match NumInputs, unless +// kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator +// will be eligible for being exported as a flex op. +template +tensorflow::Status ConvertSimpleOperatorGeneric( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { + if (NumInputs != kAnyNumInputs) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); + } auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } op->outputs.push_back(node.name()); + + if (flex == kFlexOk) { + RetainTensorFlowNodeDef(node, op); + } + model->operators.emplace_back(op); return tensorflow::Status::OK(); } -// This method supports simple operators without additional attributes. -template +// Convert a simple operator which is not valid as a flex op. +template tensorflow::Status ConvertSimpleOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); - return ConvertSimpleOperator(node, tf_import_flags, model); + return ConvertSimpleOperatorGeneric( + node, tf_import_flags, model); +} + +// Convert a simple operator which is valid as a flex op. +template +tensorflow::Status ConvertSimpleOperatorFlexOk( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + return ConvertSimpleOperatorGeneric( + node, tf_import_flags, model); +} + +void GetOutputNamesFromNodeDef(const NodeDef& node, + const tensorflow::OpDef& op_def, + TensorFlowUnsupportedOperator* op) { + int next_output = 0; + auto add_output = [&node, &next_output, op]() { + if (next_output == 0) { + op->outputs.push_back(node.name()); // Implicit :0. + } else { + op->outputs.push_back(absl::StrCat(node.name(), ":", next_output)); + } + ++next_output; + }; + for (int i = 0; i < op_def.output_arg_size(); ++i) { + string multiples = op_def.output_arg(i).number_attr(); + if (!multiples.empty()) { + CHECK(HasAttr(node, multiples)) << "No attr named " << multiples; + int num_outputs = GetIntAttr(node, multiples); + for (int j = 0; j < num_outputs; ++j) { + add_output(); + } + } else { + string list = op_def.output_arg(i).type_list_attr(); + if (!list.empty()) { + CHECK(HasAttr(node, list)) << "No attr named " << list; + const AttrValue::ListValue& list_value = GetListAttr(node, list); + for (int j = 0; j < list_value.type_size(); ++j) { + add_output(); + } + } else { + add_output(); + } + } + } +} + +void GetOutputTypesFromNodeDef(const NodeDef& node, + const tensorflow::OpDef& op_def, + TensorFlowUnsupportedOperator* op) { + // The the given type to the op, or clear the types if invalid. + auto add_type = [&node, op](tensorflow::DataType type) { + if (type == tensorflow::DT_INVALID) { + LOG(WARNING) << "Op node missing output type attribute: " << node.name(); + op->output_data_types.clear(); + } else { + op->output_data_types.push_back(ConvertDataType(type)); + } + }; + + // Retrieve the data type according to the OpDef definition: either the + // "type" or "type_attr" field will be set. + auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) { + if (a.type() != tensorflow::DT_INVALID) { + return a.type(); + } else if (HasAttr(node, a.type_attr())) { + return GetDataTypeAttr(node, a.type_attr()); + } else { + return tensorflow::DT_INVALID; + } + }; + + for (int i = 0; i < op_def.output_arg_size(); ++i) { + string multiples = op_def.output_arg(i).number_attr(); + if (!multiples.empty()) { + CHECK(HasAttr(node, multiples)) << "No attr named " << multiples; + int num_outputs = GetIntAttr(node, multiples); + auto type = get_type(op_def.output_arg(i)); + for (int j = 0; j < num_outputs; ++j) { + add_type(type); + } + } else { + string list = op_def.output_arg(i).type_list_attr(); + if (!list.empty()) { + CHECK(HasAttr(node, list)) << "No attr named " << list; + const AttrValue::ListValue& list_value = GetListAttr(node, list); + for (int j = 0; j < list_value.type_size(); ++j) { + add_type(list_value.type(j)); + } + } else { + add_type(get_type(op_def.output_arg(i))); + } + } + } } tensorflow::Status ConvertUnsupportedOperator( @@ -1176,19 +1285,7 @@ tensorflow::Status ConvertUnsupportedOperator( // Note that some outputs are to be multipled by a named attribute. const tensorflow::OpDef* op_def = nullptr; if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) { - int next_output = 0; - for (int i = 0; i < op_def->output_arg_size(); ++i) { - string multiples = op_def->output_arg(i).number_attr(); - int num_outputs = multiples.empty() ? 1 : GetIntAttr(node, multiples); - for (int j = 0; j < num_outputs; ++j) { - if (next_output == 0) { - op->outputs.push_back(node.name()); // Implicit :0. - } else { - op->outputs.push_back(absl::StrCat(node.name(), ":", next_output)); - } - ++next_output; - } - } + GetOutputNamesFromNodeDef(node, *op_def, op); } else { op->outputs.push_back(node.name()); // Implicit :0. } @@ -1213,19 +1310,7 @@ tensorflow::Status ConvertUnsupportedOperator( const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } else if (op_def != nullptr) { - for (const auto& output_arg : op_def->output_arg()) { - if (output_arg.type() != tensorflow::DT_INVALID) { - op->output_data_types.push_back(ConvertDataType(output_arg.type())); - } else if (HasAttr(node, output_arg.type_attr())) { - op->output_data_types.push_back( - ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr()))); - } else { - LOG(WARNING) << "Op node missing output type attribute: " - << node.name(); - op->output_data_types.clear(); - break; - } - } + GetOutputTypesFromNodeDef(node, *op_def, op); } else { // TODO(b/113613439): Figure out how to propagate types for custom ops // that have no OpDef. @@ -2143,7 +2228,7 @@ ConverterMapType GetTensorFlowNodeConverterMapForFlex() { ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map({ {"Add", ConvertSimpleOperator}, - {"AddN", ConvertSimpleOperator}, + {"AddN", ConvertSimpleOperatorFlexOk}, {"All", ConvertSimpleOperator}, {"Any", ConvertReduceOperator}, {"ArgMax", ConvertArgMaxOperator}, diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/lite/toco/import_tensorflow.h similarity index 82% rename from tensorflow/contrib/lite/toco/import_tensorflow.h rename to tensorflow/lite/toco/import_tensorflow.h index c5ff96956a748d..5b74ff2bc31a0f 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.h +++ b/tensorflow/lite/toco/import_tensorflow.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#ifndef TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#define TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_ #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/core/framework/graph.pb.h" namespace toco { @@ -30,7 +30,7 @@ struct TensorFlowImportFlags { // Do not recognize any op and import all ops as // `TensorFlowUnsupportedOperator`. This is used to populated with the - // `force_flex_ops` flag. + // `force_select_tf_ops` flag. bool import_all_ops_as_unsupported = false; }; @@ -44,4 +44,4 @@ std::unique_ptr ImportTensorFlowGraphDef( } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#endif // TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc similarity index 82% rename from tensorflow/contrib/lite/toco/import_tensorflow_test.cc rename to tensorflow/lite/toco/import_tensorflow_test.cc index 587148a930283d..30aa725f1db874 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/lite/toco/import_tensorflow_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/import_tensorflow.h" +#include "tensorflow/lite/toco/import_tensorflow.h" #include #include @@ -332,13 +332,44 @@ TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) { } TEST(ImportTest, UnsupportedOpWithMultipleOutputs) { - NodeDef node = BuildNode("Unpack", {}); + // This test needs an existing TensorFlow op to run correctly, because it + // read the OpDef from the global registry. The complex output setup of + // ParseExample allows us to test all nuances here, but we will need to add + // attributes to match the specification in the OpDef. + NodeDef node = BuildNode("ParseExample", {}); + + // Nsparse defines how many sparse indices and shapes there are. Here we set + // Nsparse to 2, meaning there will be 2 INT64 tensors for 'sparse_indices' + // and 2 INT64 tensors for 'sparse_shapes. The type of those tensors is + // defined in the OpDef. + { + AttrValue value_attr; + SetAttrValue(2, &value_attr); + (*node.mutable_attr())["Nsparse"] = value_attr; + } - // Unpack's OpDef has a single output which gets multiplied based on the - // "num" attribute of the NodeDef. - AttrValue value_attr; - SetAttrValue(3, &value_attr); // 3 outputs. - (*node.mutable_attr())["num"] = value_attr; + // The there will be a number of 'sparse_values' tensors, defined by the + // attribute 'sparse_types', which is a list of types. + { + AttrValue value_attr; + std::vector types; + types.push_back(tensorflow::DT_FLOAT); + types.push_back(tensorflow::DT_STRING); + SetAttrValue(types, &value_attr); + (*node.mutable_attr())["sparse_types"] = value_attr; + } + + // And finally there will be 'dense_values' tensors, which are controlled by + // the 'Tdense' attribute. + { + AttrValue value_attr; + std::vector types; + types.push_back(tensorflow::DT_STRING); + types.push_back(tensorflow::DT_FLOAT); + types.push_back(tensorflow::DT_INT64); + SetAttrValue(types, &value_attr); + (*node.mutable_attr())["Tdense"] = value_attr; + } Model model; EXPECT_TRUE(ImportFlexNode(node, &model).ok()); @@ -349,10 +380,34 @@ TEST(ImportTest, UnsupportedOpWithMultipleOutputs) { static_cast( model.operators[0].get()); - ASSERT_EQ(op->outputs.size(), 3); + ASSERT_EQ(op->outputs.size(), 9); + ASSERT_EQ(op->output_data_types.size(), 9); + + // The 'sparse_indices' output tensors. ASSERT_EQ(op->outputs[0], "Node1"); ASSERT_EQ(op->outputs[1], "Node1:1"); + ASSERT_EQ(op->output_data_types[0], ArrayDataType::kInt64); + ASSERT_EQ(op->output_data_types[1], ArrayDataType::kInt64); + + // The 'sparse_values' output tensors. ASSERT_EQ(op->outputs[2], "Node1:2"); + ASSERT_EQ(op->outputs[3], "Node1:3"); + ASSERT_EQ(op->output_data_types[2], ArrayDataType::kFloat); + ASSERT_EQ(op->output_data_types[3], ArrayDataType::kString); + + // The 'sparse_shapes' output tensors. + ASSERT_EQ(op->outputs[4], "Node1:4"); + ASSERT_EQ(op->outputs[5], "Node1:5"); + ASSERT_EQ(op->output_data_types[4], ArrayDataType::kInt64); + ASSERT_EQ(op->output_data_types[5], ArrayDataType::kInt64); + + // The 'dense_shapes' output tensors. + ASSERT_EQ(op->outputs[6], "Node1:6"); + ASSERT_EQ(op->outputs[7], "Node1:7"); + ASSERT_EQ(op->outputs[8], "Node1:8"); + ASSERT_EQ(op->output_data_types[6], ArrayDataType::kString); + ASSERT_EQ(op->output_data_types[7], ArrayDataType::kFloat); + ASSERT_EQ(op->output_data_types[8], ArrayDataType::kInt64); } } // namespace diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/lite/toco/model.h similarity index 99% rename from tensorflow/contrib/lite/toco/model.h rename to tensorflow/lite/toco/model.h index 716dade6c84a26..f85e1c287879e6 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#ifndef TENSORFLOW_LITE_TOCO_MODEL_H_ +#define TENSORFLOW_LITE_TOCO_MODEL_H_ #include #include @@ -24,10 +24,10 @@ limitations under the License. #include #include "absl/types/optional.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -2144,4 +2144,4 @@ class Model { }; } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#endif // TENSORFLOW_LITE_TOCO_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/lite/toco/model_cmdline_flags.cc similarity index 98% rename from tensorflow/contrib/lite/toco/model_cmdline_flags.cc rename to tensorflow/lite/toco/model_cmdline_flags.cc index b6a401aaf2f002..717a28bc615e0a 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/lite/toco/model_cmdline_flags.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/lite/toco/model_cmdline_flags.h" #include #include @@ -22,9 +22,9 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "tensorflow/contrib/lite/toco/args.h" -#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/args.h" +#include "tensorflow/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/lite/toco/model_cmdline_flags.h similarity index 80% rename from tensorflow/contrib/lite/toco/model_cmdline_flags.h rename to tensorflow/lite/toco/model_cmdline_flags.h index c868d5c7d0b5a6..1642e053199b1a 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.h +++ b/tensorflow/lite/toco/model_cmdline_flags.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#ifndef TENSORFLOW_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#define TENSORFLOW_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ #include #include #include -#include "tensorflow/contrib/lite/toco/args.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/lite/toco/args.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/types.pb.h" namespace toco { // Parse and remove arguments for models (in toco). Returns true if parsing @@ -40,4 +40,4 @@ ParsedModelFlags* GlobalParsedModelFlags(); } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#endif // TENSORFLOW_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/lite/toco/model_flags.proto similarity index 99% rename from tensorflow/contrib/lite/toco/model_flags.proto rename to tensorflow/lite/toco/model_flags.proto index 6c1c53658c0736..bcdac295d261c0 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/lite/toco/model_flags.proto @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; -import "tensorflow/contrib/lite/toco/types.proto"; +import "tensorflow/lite/toco/types.proto"; package toco; diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD similarity index 54% rename from tensorflow/contrib/lite/toco/python/BUILD rename to tensorflow/lite/toco/python/BUILD index cf97ba7084d48e..07056f66c35536 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -6,19 +6,32 @@ load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "py_binary") +config_setting( + name = "tflite_convert_with_select_tf_ops", + define_values = {"tflite_convert_with_select_tf_ops": "true"}, + visibility = ["//visibility:public"], +) + cc_library( name = "toco_python_api", srcs = ["toco_python_api.cc"], hdrs = ["toco_python_api.h"], deps = [ - "//tensorflow/contrib/lite/toco:model_flags_proto_cc", - "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", - "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options", - "//tensorflow/contrib/lite/toco:toco_port", - "//tensorflow/contrib/lite/toco:toco_tooling", - "//tensorflow/core:lib", "//third_party/python_runtime:headers", - ], + "//tensorflow/core:lib", + "//tensorflow/lite/toco:model_flags_proto_cc", + "//tensorflow/lite/toco:toco_flags_proto_cc", + "//tensorflow/lite/toco:toco_graphviz_dump_options", + "//tensorflow/lite/toco:toco_port", + "//tensorflow/lite/toco:toco_tooling", + ] + select({ + # This is required when running `tflite_convert` from `bazel`. + # It requires to link with TensorFlow Ops to get the op definitions. + ":tflite_convert_with_select_tf_ops": [ + "//tensorflow/core:ops", + ], + "//conditions:default": [], + }), ) tf_py_wrap_cc( @@ -26,8 +39,8 @@ tf_py_wrap_cc( srcs = ["toco.i"], deps = [ ":toco_python_api", - "//tensorflow/contrib/lite/toco:model_flags_proto_cc", - "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/lite/toco:model_flags_proto_cc", + "//tensorflow/lite/toco:toco_flags_proto_cc", "//third_party/python_runtime:headers", "@com_google_absl//absl/strings", ], @@ -48,8 +61,8 @@ tf_py_test( srcs = ["toco_from_protos_test.py"], additional_deps = [ "//tensorflow:tensorflow_py", - "//tensorflow/contrib/lite/toco:model_flags_proto_py", - "//tensorflow/contrib/lite/toco:toco_flags_proto_py", + "//tensorflow/lite/toco:model_flags_proto_py", + "//tensorflow/lite/toco:toco_flags_proto_py", ], data = [ ":toco_from_protos", diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/lite/toco/python/toco.i similarity index 95% rename from tensorflow/contrib/lite/toco/python/toco.i rename to tensorflow/lite/toco/python/toco.i index 0d2fbdd67b3aa5..c7dfdc35ab274f 100644 --- a/tensorflow/contrib/lite/toco/python/toco.i +++ b/tensorflow/lite/toco/python/toco.i @@ -16,7 +16,7 @@ limitations under the License. %include "std_string.i" %{ -#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" +#include "tensorflow/lite/toco/python/toco_python_api.h" %} namespace toco { diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos.py b/tensorflow/lite/toco/python/toco_from_protos.py similarity index 96% rename from tensorflow/contrib/lite/toco/python/toco_from_protos.py rename to tensorflow/lite/toco/python/toco_from_protos.py index c0b032083b2347..152dd241eabba3 100644 --- a/tensorflow/contrib/lite/toco/python/toco_from_protos.py +++ b/tensorflow/lite/toco/python/toco_from_protos.py @@ -19,7 +19,7 @@ import argparse import sys -from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco +from tensorflow.lite.toco.python import tensorflow_wrap_toco from tensorflow.python.platform import app FLAGS = None diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/lite/toco/python/toco_from_protos_test.py similarity index 95% rename from tensorflow/contrib/lite/toco/python/toco_from_protos_test.py rename to tensorflow/lite/toco/python/toco_from_protos_test.py index 75c1c8970c9cd5..34cfd2c59fdc3a 100644 --- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py +++ b/tensorflow/lite/toco/python/toco_from_protos_test.py @@ -20,9 +20,9 @@ import tempfile import tensorflow as tf -from tensorflow.contrib.lite.toco import model_flags_pb2 -from tensorflow.contrib.lite.toco import toco_flags_pb2 -from tensorflow.contrib.lite.toco import types_pb2 +from tensorflow.lite.toco import model_flags_pb2 +from tensorflow.lite.toco import toco_flags_pb2 +from tensorflow.lite.toco import types_pb2 from tensorflow.python.platform import googletest from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc similarity index 90% rename from tensorflow/contrib/lite/toco/python/toco_python_api.cc rename to tensorflow/lite/toco/python/toco_python_api.cc index 302b7f9bd4037b..ce8e3c9df88ba5 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include "tensorflow/core/platform/logging.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_tooling.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/python/toco_python_api.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_tooling.h" +#include "tensorflow/lite/toco/toco_types.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h similarity index 88% rename from tensorflow/contrib/lite/toco/python/toco_python_api.h rename to tensorflow/lite/toco/python/toco_python_api.h index ee054bbed9823d..4ab0961e1276e4 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ +#ifndef TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ +#define TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ #include #include @@ -33,4 +33,4 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ +#endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/lite/toco/runtime/common.h similarity index 78% rename from tensorflow/contrib/lite/toco/runtime/common.h rename to tensorflow/lite/toco/runtime/common.h index 3c6828840c4a96..1f83be8fa81e0f 100644 --- a/tensorflow/contrib/lite/toco/runtime/common.h +++ b/tensorflow/lite/toco/runtime/common.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#ifndef TENSORFLOW_LITE_TOCO_RUNTIME_COMMON_H_ +#define TENSORFLOW_LITE_TOCO_RUNTIME_COMMON_H_ #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK @@ -21,6 +21,6 @@ limitations under the License. #endif #endif -#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/common.h" -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#endif // TENSORFLOW_LITE_TOCO_RUNTIME_COMMON_H_ diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/lite/toco/runtime/types.h similarity index 72% rename from tensorflow/contrib/lite/toco/runtime/types.h rename to tensorflow/lite/toco/runtime/types.h index 207f2c1706ef4c..eac9b8af6e6c78 100644 --- a/tensorflow/contrib/lite/toco/runtime/types.h +++ b/tensorflow/lite/toco/runtime/types.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#ifndef TENSORFLOW_LITE_TOCO_RUNTIME_TYPES_H_ +#define TENSORFLOW_LITE_TOCO_RUNTIME_TYPES_H_ -#include "tensorflow/contrib/lite/kernels/internal/common.h" -#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" -#include "tensorflow/contrib/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/types.h" namespace toco { @@ -30,4 +30,4 @@ using tflite::RequiredBufferSizeForDims; } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#endif // TENSORFLOW_LITE_TOCO_RUNTIME_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/lite/toco/tensorflow_graph_matching/BUILD similarity index 80% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD rename to tensorflow/lite/toco/tensorflow_graph_matching/BUILD index ea1fc2827ead7e..56acc284cc06d6 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD +++ b/tensorflow/lite/toco/tensorflow_graph_matching/BUILD @@ -16,7 +16,7 @@ cc_library( "cluster_utils.h", ], deps = [ - "//tensorflow/contrib/lite/toco:toco_port", + "//tensorflow/lite/toco:toco_port", ], ) @@ -30,9 +30,9 @@ cc_library( ], deps = [ ":cluster_utils", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:tooling_util", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", ], ) @@ -48,11 +48,11 @@ cc_library( deps = [ ":cluster", ":cluster_utils", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:toco_port", - "//tensorflow/contrib/lite/toco:tooling_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:toco_port", + "//tensorflow/lite/toco:tooling_util", "@protobuf_archive//:protobuf_headers", ], ) @@ -85,7 +85,7 @@ cc_library( ":cluster", ":cluster_utils", ":resolve_svdf", - "//tensorflow/contrib/lite/toco:tooling_util", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco:tooling_util", ], ) diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc b/tensorflow/lite/toco/tensorflow_graph_matching/cluster.cc similarity index 95% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc rename to tensorflow/lite/toco/tensorflow_graph_matching/cluster.cc index 98a130ea39c45c..afce05dc7a932f 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc +++ b/tensorflow/lite/toco/tensorflow_graph_matching/cluster.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/lite/toco/tensorflow_graph_matching/cluster.h similarity index 89% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h rename to tensorflow/lite/toco/tensorflow_graph_matching/cluster.h index fda7743a27e794..af268ddd3703f3 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h +++ b/tensorflow/lite/toco/tensorflow_graph_matching/cluster.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ +#ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ +#define TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -98,4 +98,4 @@ class ClusterFactoryInterface { } // end namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ +#endif // TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc b/tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.cc similarity index 95% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc rename to tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.cc index 14c3cd6487841d..8a010ef8208ee9 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc +++ b/tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/toco_types.h" namespace toco { bool StrContains(const string& x, const string& search_pattern) { diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h b/tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h similarity index 82% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h rename to tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h index b57bded305ffbb..9b9c4fc20862c0 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h +++ b/tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_ +#ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_ +#define TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_ #include @@ -30,4 +30,4 @@ void Transpose2DTensor(const float* tensor, int row, int col, } // end namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_ +#endif // TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.cc similarity index 93% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc rename to tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.cc index 5e421ba944cccd..7a1875120788a5 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc +++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include #include #include -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h similarity index 86% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h rename to tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h index 3334552afb1bec..d7afcced7b7ac5 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h +++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_ +#ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_ +#define TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_ #include #include #include -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -60,4 +60,4 @@ std::unique_ptr MaybeReplaceCompositeSubgraph( } // end namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_ +#endif // TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc similarity index 96% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc rename to tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc index d6a099817c7b88..fcd9ee45d984f0 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc +++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h" #include #include @@ -22,11 +22,11 @@ limitations under the License. #include #include "google/protobuf/map.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h similarity index 85% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h rename to tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h index 383fd99dff225c..649cadfa066f94 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h +++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_ +#ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_ +#define TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_ #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -79,4 +79,4 @@ class SvdfClusterFactory : public ClusterFactoryInterface { } // end namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_ +#endif // TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc similarity index 96% rename from tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc rename to tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc index 646d048496c279..f66b59ccce663f 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc +++ b/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h" #include #include @@ -20,9 +20,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" -#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/lite/toco/tensorflow_util.cc similarity index 97% rename from tensorflow/contrib/lite/toco/tensorflow_util.cc rename to tensorflow/lite/toco/tensorflow_util.cc index 0e7e9c41a06658..db9388b040c4e9 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_util.cc +++ b/tensorflow/lite/toco/tensorflow_util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tensorflow_util.h" +#include "tensorflow/lite/toco/tensorflow_util.h" #include #include @@ -24,8 +24,8 @@ limitations under the License. #include "google/protobuf/map.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/lite/toco/tensorflow_util.h similarity index 81% rename from tensorflow/contrib/lite/toco/tensorflow_util.h rename to tensorflow/lite/toco/tensorflow_util.h index 61f91042685288..010fbe88b21790 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_util.h +++ b/tensorflow/lite/toco/tensorflow_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_UTIL_H_ +#define TENSORFLOW_LITE_TOCO_TENSORFLOW_UTIL_H_ #include #include -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/toco/model.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -29,4 +29,4 @@ void LogDumpGraphDef(int log_level, const string& message, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#endif // TENSORFLOW_LITE_TOCO_TENSORFLOW_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD similarity index 68% rename from tensorflow/contrib/lite/toco/tflite/BUILD rename to tensorflow/lite/toco/tflite/BUILD index e97bedd2036bb0..99c4f8edebe518 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -25,12 +25,12 @@ cc_library( ], deps = [ ":types", - "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/contrib/lite/toco:graph_transformations", - "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:ptr_util", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/toco:graph_transformations", + "//tensorflow/lite/toco:model", "@com_google_absl//absl/memory", "@flatbuffers", ], @@ -44,9 +44,9 @@ tf_cc_test( tags = ["no_oss"], deps = [ ":operator", - "//tensorflow/contrib/lite/toco:tooling_util", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/toco:tooling_util", "@com_google_googletest//:gtest_main", "@flatbuffers", ], @@ -61,9 +61,9 @@ cc_library( "types.h", ], deps = [ - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/contrib/lite/toco:model", + "//tensorflow/lite:string_util", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/toco:model", ], ) @@ -92,11 +92,11 @@ cc_library( deps = [ ":operator", ":types", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:tooling_util", - "//tensorflow/contrib/lite/tools/optimize:quantize_weights", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", + "//tensorflow/lite/tools/optimize:quantize_weights", "@com_google_absl//absl/strings", "@flatbuffers", ], @@ -110,8 +110,8 @@ tf_cc_test( tags = ["no_oss"], deps = [ ":export", - "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/core:ops", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", ], ) @@ -128,11 +128,11 @@ cc_library( deps = [ ":operator", ":types", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/contrib/lite/toco:model", - "//tensorflow/contrib/lite/toco:tooling_util", - "//tensorflow/contrib/lite/tools:verifier", + "//tensorflow/lite:framework", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", + "//tensorflow/lite/tools:verifier", "@flatbuffers", ], ) @@ -145,9 +145,9 @@ tf_cc_test( tags = ["no_oss"], deps = [ ":import", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/core:ops", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/lite/toco/tflite/builtin_operator.h similarity index 90% rename from tensorflow/contrib/lite/toco/tflite/builtin_operator.h rename to tensorflow/lite/toco/tflite/builtin_operator.h index cfe7ecd9f98261..ea012ff6e706ae 100644 --- a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h +++ b/tensorflow/lite/toco/tflite/builtin_operator.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/operator.h" namespace toco { @@ -71,4 +71,4 @@ class BuiltinOperator : public BaseOperator { } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/lite/toco/tflite/custom_operator.h similarity index 90% rename from tensorflow/contrib/lite/toco/tflite/custom_operator.h rename to tensorflow/lite/toco/tflite/custom_operator.h index bd5713618ff379..2ca740bb90d5ad 100644 --- a/tensorflow/contrib/lite/toco/tflite/custom_operator.h +++ b/tensorflow/lite/toco/tflite/custom_operator.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ #include "flatbuffers/flexbuffers.h" #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/operator.h" namespace toco { @@ -71,4 +71,4 @@ class CustomOperator : public BaseOperator { } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc similarity index 94% rename from tensorflow/contrib/lite/toco/tflite/export.cc rename to tensorflow/lite/toco/tflite/export.cc index 30efd67f8c2130..489c21295ef8fc 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/export.h" +#include "tensorflow/lite/toco/tflite/export.h" #include "flatbuffers/flexbuffers.h" #include "absl/strings/str_join.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/toco/tflite/operator.h" -#include "tensorflow/contrib/lite/toco/tflite/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/types.h" +#include "tensorflow/lite/toco/tooling_util.h" +#include "tensorflow/lite/tools/optimize/quantize_weights.h" +#include "tensorflow/lite/version.h" namespace toco { @@ -108,7 +108,7 @@ namespace details { OperatorKey::OperatorKey( const ::toco::Operator& op, const std::map>& ops_by_type, - bool allow_flex_ops) { + bool enable_select_tf_ops) { // Get the op name (by Toco definition). string name = HelpfulOperatorTypeName(op); @@ -136,7 +136,8 @@ OperatorKey::OperatorKey( static_cast(op); const auto tensorflow_op = unsupported_op.tensorflow_op; - if (ShouldExportAsFlexOp(allow_flex_ops, unsupported_op.tensorflow_op)) { + if (ShouldExportAsFlexOp(enable_select_tf_ops, + unsupported_op.tensorflow_op)) { is_custom_op_ = false; is_flex_op_ = true; flex_tensorflow_op_ = tensorflow_op; @@ -145,7 +146,7 @@ OperatorKey::OperatorKey( } else { custom_code_ = tensorflow_op; } - } else if (allow_flex_ops && !op.tensorflow_node_def.empty()) { + } else if (enable_select_tf_ops && !op.tensorflow_node_def.empty()) { // For Toco-supported/TFLite-unsupported ops, if the TensorFlow NodeDef // is retained in the Toco Operator, we produce a Flex op if Flex mode // is enabled. @@ -186,11 +187,11 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, const std::map>& ops_by_type, - bool allow_flex_ops) { + bool enable_select_tf_ops) { // First find a list of unique operator types. std::set keys; for (const auto& op : model.operators) { - keys.insert(OperatorKey(*op, ops_by_type, allow_flex_ops)); + keys.insert(OperatorKey(*op, ops_by_type, enable_select_tf_ops)); } // Now assign indices to them and fill in the map. int index = 0; @@ -302,7 +303,7 @@ Offset>> ExportOperatorCodes( for (const auto& op : model.operators) { const details::OperatorKey operator_key = - details::OperatorKey(*op, ops_by_type, params.allow_flex_ops); + details::OperatorKey(*op, ops_by_type, params.enable_select_tf_ops); int op_index = operators_map.at(operator_key); flatbuffers::Offset custom_code = 0; @@ -346,7 +347,7 @@ Offset>> ExportOperators( } const auto key = - details::OperatorKey(*op, ops_by_type, params.allow_flex_ops); + details::OperatorKey(*op, ops_by_type, params.enable_select_tf_ops); int op_index = operators_map.at(key); auto tflite_op_it = ops_by_type.find(op->type); @@ -405,7 +406,7 @@ Offset>> ExportBuffers( tensorflow::Status Export(const Model& model, string* output_file_contents, const ExportParams& params) { - const auto ops_by_type = BuildOperatorByTypeMap(params.allow_flex_ops); + const auto ops_by_type = BuildOperatorByTypeMap(params.enable_select_tf_ops); return Export(model, output_file_contents, params, ops_by_type); } @@ -420,7 +421,7 @@ tensorflow::Status Export( details::OperatorsMap operators_map; details::LoadOperatorsMap(model, &operators_map, ops_by_type, - params.allow_flex_ops); + params.enable_select_tf_ops); std::vector buffers_to_write; Array empty_array; @@ -486,7 +487,7 @@ tensorflow::Status Export( "40-tflite-op-request.md\n and pasting the following:\n\n"; }; - if (params.allow_flex_ops) { + if (params.enable_select_tf_ops) { return tensorflow::errors::InvalidArgument(absl::StrCat( please_report_bug_message(), "Some of the operators in the model are not supported by " @@ -494,7 +495,7 @@ tensorflow::Status Export( "TensorFlow. If you have a custom " "implementation for them you can disable this error with " "--allow_custom_ops, or by setting allow_custom_ops=True " - "when calling tf.contrib.lite.TFLiteConverter(). Here is a list " + "when calling tf.lite.TFLiteConverter(). Here is a list " "of builtin operators you are using: ", absl::StrJoin(builtin_ops, ", "), ". Here is a list " @@ -506,12 +507,12 @@ tensorflow::Status Export( "Some of the operators in the model are not supported by " "the standard TensorFlow Lite runtime. If those are native " "TensorFlow operators, you might be able to use the extended " - "runtime by passing --allow_flex_ops, or by setting " + "runtime by passing --enable_select_tf_ops, or by setting " "target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling " - "tf.contrib.lite.TFLiteConverter(). Otherwise, if you have a " + "tf.lite.TFLiteConverter(). Otherwise, if you have a " "custom implementation for them you can disable this error with " "--allow_custom_ops, or by setting allow_custom_ops=True " - "when calling tf.contrib.lite.TFLiteConverter(). Here is a list " + "when calling tf.lite.TFLiteConverter(). Here is a list " "of builtin operators you are using: ", absl::StrJoin(builtin_ops, ", "), ". Here is a list " diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/lite/toco/tflite/export.h similarity index 93% rename from tensorflow/contrib/lite/toco/tflite/export.h rename to tensorflow/lite/toco/tflite/export.h index 0a43e01c2e5a7f..adf6757a3027e5 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/lite/toco/tflite/export.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_ -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tflite/operator.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tflite/operator.h" +#include "tensorflow/lite/util.h" namespace toco { @@ -26,7 +26,7 @@ namespace tflite { // The parameters for exporting a TFLite model. struct ExportParams { bool allow_custom_ops = false; - bool allow_flex_ops = false; + bool enable_select_tf_ops = false; bool quantize_weights = false; }; @@ -90,7 +90,7 @@ class OperatorKey { OperatorKey( const ::toco::Operator& op, const std::map>& ops_by_type, - bool allow_flex_ops); + bool enable_select_tf_ops); // Construct OperatorKey by type, custom code and version. // Note that this construct doesn't set the additional information including @@ -165,10 +165,10 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, const std::map>& ops_by_type, - bool allow_flex_ops); + bool enable_select_tf_ops); } // namespace details } // namespace tflite } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc similarity index 97% rename from tensorflow/contrib/lite/toco/tflite/export_test.cc rename to tensorflow/lite/toco/tflite/export_test.cc index f574874064458c..b6c67772acadcc 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/export.h" +#include "tensorflow/lite/toco/tflite/export.h" #include #include -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" -#include "tensorflow/contrib/lite/toco/tflite/operator.h" -#include "tensorflow/contrib/lite/toco/tflite/types.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/core/framework/node_def.pb.h" namespace toco { @@ -162,7 +162,7 @@ TEST_F(ExportTest, Export) { ExportParams params; params.allow_custom_ops = true; - params.allow_flex_ops = false; + params.enable_select_tf_ops = false; params.quantize_weights = false; EXPECT_THAT(ExportAndSummarizeOperators(params), @@ -192,7 +192,7 @@ class OpSetsTest : public ExportTest { void SetAllowedOpSets(std::initializer_list sets) { import_all_ops_as_unsupported_ = true; params_.allow_custom_ops = false; - params_.allow_flex_ops = false; + params_.enable_select_tf_ops = false; params_.quantize_weights = false; for (OpSet i : sets) { @@ -201,7 +201,7 @@ class OpSetsTest : public ExportTest { import_all_ops_as_unsupported_ = false; break; case kSelectTfOps: - params_.allow_flex_ops = true; + params_.enable_select_tf_ops = true; break; case kCustomOps: params_.allow_custom_ops = true; diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/lite/toco/tflite/import.cc similarity index 95% rename from tensorflow/contrib/lite/toco/tflite/import.cc rename to tensorflow/lite/toco/tflite/import.cc index 1dd4915b31413e..88028aa144f2dc 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/lite/toco/tflite/import.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/import.h" +#include "tensorflow/lite/toco/tflite/import.h" #include "flatbuffers/flexbuffers.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/toco/tflite/operator.h" -#include "tensorflow/contrib/lite/toco/tflite/types.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/contrib/lite/tools/verifier.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/types.h" +#include "tensorflow/lite/toco/tooling_util.h" +#include "tensorflow/lite/tools/verifier.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/lite/toco/tflite/import.h similarity index 84% rename from tensorflow/contrib/lite/toco/tflite/import.h rename to tensorflow/lite/toco/tflite/import.h index 280677bae189fa..f5de3b53b5bc24 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.h +++ b/tensorflow/lite/toco/tflite/import.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_IMPORT_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_IMPORT_H_ -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -46,4 +46,4 @@ void LoadOperatorsTable(const ::tflite::Model &input_model, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_IMPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/lite/toco/tflite/import_test.cc similarity index 98% rename from tensorflow/contrib/lite/toco/tflite/import_test.cc rename to tensorflow/lite/toco/tflite/import_test.cc index edd22f783f03b1..93ab5141abe81c 100644 --- a/tensorflow/contrib/lite/toco/tflite/import_test.cc +++ b/tensorflow/lite/toco/tflite/import_test.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/import.h" +#include "tensorflow/lite/toco/tflite/import.h" #include "flatbuffers/flexbuffers.h" #include #include -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/version.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc similarity index 98% rename from tensorflow/contrib/lite/toco/tflite/operator.cc rename to tensorflow/lite/toco/tflite/operator.cc index 0f4ea5cd214517..015029e1cbd57c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/operator.h" // TODO(ycling): Consider refactoring to extract the LSTM definition out of // graph_transformation module. -#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" -#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" -#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" -#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" -#include "tensorflow/contrib/lite/toco/tflite/types.h" -#include "tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.h" +#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/lite/toco/tflite/custom_operator.h" +#include "tensorflow/lite/toco/tflite/simple_operator.h" +#include "tensorflow/lite/toco/tflite/types.h" +#include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -1240,8 +1240,8 @@ std::unique_ptr WriteFlexOpOptions( class TensorFlowUnsupported : public BaseOperator { public: TensorFlowUnsupported(const string& name, OperatorType type, - bool allow_flex_ops) - : BaseOperator(name, type), allow_flex_ops_(allow_flex_ops) {} + bool enable_select_tf_ops) + : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {} Options Serialize(const Operator& op, flatbuffers::FlatBufferBuilder* builder) const override { @@ -1272,7 +1272,7 @@ class TensorFlowUnsupported : public BaseOperator { std::unique_ptr WriteOptions( const TensorFlowUnsupportedOperator& op) const { - if (allow_flex_ops_) { + if (enable_select_tf_ops_) { return WriteFlexOpOptions(op.tensorflow_node_def); } auto fbb = absl::make_unique(); @@ -1283,7 +1283,7 @@ class TensorFlowUnsupported : public BaseOperator { return std::unique_ptr(); } - if (ShouldExportAsFlexOp(allow_flex_ops_, node_def.op())) { + if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) { fbb->Vector([&]() { fbb->String(node_def.op()); fbb->String(op.tensorflow_node_def); @@ -1399,13 +1399,13 @@ class TensorFlowUnsupported : public BaseOperator { } private: - const bool allow_flex_ops_; + const bool enable_select_tf_ops_; }; namespace { // Build a vector containing all the known operators. std::vector> BuildOperatorList( - bool allow_flex_ops = false) { + bool enable_select_tf_ops = false) { std::vector> ops; using tensorflow::MakeUnique; // Builtin Operators. @@ -1522,8 +1522,9 @@ std::vector> BuildOperatorList( MakeUnique("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.push_back(MakeUnique( "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); - ops.push_back(MakeUnique( - "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_flex_ops)); + ops.push_back(MakeUnique("TENSORFLOW_UNSUPPORTED", + OperatorType::kUnsupported, + enable_select_tf_ops)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. @@ -1605,11 +1606,11 @@ std::vector> BuildOperatorList( } // namespace std::map> BuildOperatorByTypeMap( - bool allow_flex_ops) { + bool enable_select_tf_ops) { std::map> result; std::vector> ops = - BuildOperatorList(allow_flex_ops); + BuildOperatorList(enable_select_tf_ops); for (auto& op : ops) { result[op->type()] = std::move(op); } @@ -1618,11 +1619,11 @@ std::map> BuildOperatorByTypeMap( } std::map> BuildOperatorByNameMap( - bool allow_flex_ops) { + bool enable_select_tf_ops) { std::map> result; std::vector> ops = - BuildOperatorList(allow_flex_ops); + BuildOperatorList(enable_select_tf_ops); for (auto& op : ops) { result[op->name()] = std::move(op); } @@ -1630,10 +1631,10 @@ std::map> BuildOperatorByNameMap( return result; } -bool ShouldExportAsFlexOp(bool allow_flex_ops, +bool ShouldExportAsFlexOp(bool enable_select_tf_ops, const string& tensorflow_op_name) { // If Flex ops aren't allow at all, simply return false. - if (!allow_flex_ops) { + if (!enable_select_tf_ops) { return false; } // Check if we can find the `OpDef` for the TensorFlow op. If we can find diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/lite/toco/tflite/operator.h similarity index 90% rename from tensorflow/contrib/lite/toco/tflite/operator.h rename to tensorflow/lite/toco/tflite/operator.h index 6e2a41bf53ad23..4ac531579c12c8 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/lite/toco/tflite/operator.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ #include "flatbuffers/flatbuffers.h" #include "flatbuffers/flexbuffers.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -27,15 +27,15 @@ namespace tflite { class BaseOperator; // Return a map contained all know TF Lite Operators, keyed by their names. -// TODO(ycling): The pattern to propagate parameters (e.g. allow_flex_ops) +// TODO(ycling): The pattern to propagate parameters (e.g. enable_select_tf_ops) // is ugly here. Consider refactoring. std::map> BuildOperatorByNameMap( - bool allow_flex_ops = false); + bool enable_select_tf_ops = false); // Return a map contained all know TF Lite Operators, keyed by the type of // their tf.mini counterparts. std::map> BuildOperatorByTypeMap( - bool allow_flex_ops = false); + bool enable_select_tf_ops = false); // Write the custom option FlexBuffer with a serialized TensorFlow NodeDef // for a Flex op. @@ -115,11 +115,11 @@ class BaseOperator { // Helper function to determine if a unsupported TensorFlow op should be // exported as an Flex op or a regular custom op. -bool ShouldExportAsFlexOp(bool allow_flex_ops, +bool ShouldExportAsFlexOp(bool enable_select_tf_ops, const string& tensorflow_op_name); } // namespace tflite } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc similarity index 99% rename from tensorflow/contrib/lite/toco/tflite/operator_test.cc rename to tensorflow/lite/toco/tflite/operator_test.cc index 37a0ad2d1dff15..8a776cbf0be57d 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/operator.h" #include "flatbuffers/flexbuffers.h" #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/lite/toco/tflite/simple_operator.h similarity index 86% rename from tensorflow/contrib/lite/toco/tflite/simple_operator.h rename to tensorflow/lite/toco/tflite/simple_operator.h index a7f7e886f61d3b..e3e4c8551e931f 100644 --- a/tensorflow/contrib/lite/toco/tflite/simple_operator.h +++ b/tensorflow/lite/toco/tflite/simple_operator.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ -#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tflite/operator.h" namespace toco { @@ -49,4 +49,4 @@ class SimpleOperator : public BaseOperator { } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/lite/toco/tflite/types.cc similarity index 98% rename from tensorflow/contrib/lite/toco/tflite/types.cc rename to tensorflow/lite/toco/tflite/types.cc index 754f0b4b8c6613..f878dafc1ed3c8 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/lite/toco/tflite/types.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/types.h" -#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/lite/toco/tflite/types.h" +#include "tensorflow/lite/string_util.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/lite/toco/tflite/types.h similarity index 87% rename from tensorflow/contrib/lite/toco/tflite/types.h rename to tensorflow/lite/toco/tflite/types.h index 3923756fc94e31..bc2edb74297426 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.h +++ b/tensorflow/lite/toco/tflite/types.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_TYPES_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_TYPES_H_ -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/toco/model.h" namespace toco { @@ -55,4 +55,4 @@ struct ActivationFunction { } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/lite/toco/tflite/types_test.cc similarity index 99% rename from tensorflow/contrib/lite/toco/tflite/types_test.cc rename to tensorflow/lite/toco/tflite/types_test.cc index 8e9f30ba3a6e6b..efa2911b5b8c25 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/lite/toco/tflite/types_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/types.h" +#include "tensorflow/lite/toco/tflite/types.h" #include diff --git a/tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.cc b/tensorflow/lite/toco/tflite/whitelisted_flex_ops.cc similarity index 99% rename from tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.cc rename to tensorflow/lite/toco/tflite/whitelisted_flex_ops.cc index d605a006becc60..221e9b8e34e2b8 100644 --- a/tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.cc +++ b/tensorflow/lite/toco/tflite/whitelisted_flex_ops.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.h" +#include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h" #include diff --git a/tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.h b/tensorflow/lite/toco/tflite/whitelisted_flex_ops.h similarity index 86% rename from tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.h rename to tensorflow/lite/toco/tflite/whitelisted_flex_ops.h index ed2435fbe0bb36..2559a7052852ca 100644 --- a/tensorflow/contrib/lite/toco/tflite/whitelisted_flex_ops.h +++ b/tensorflow/lite/toco/tflite/whitelisted_flex_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_WHITELISTED_FLEX_OPS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_WHITELISTED_FLEX_OPS_H_ +#ifndef TENSORFLOW_LITE_TOCO_TFLITE_WHITELISTED_FLEX_OPS_H_ +#define TENSORFLOW_LITE_TOCO_TFLITE_WHITELISTED_FLEX_OPS_H_ #include @@ -32,4 +32,4 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name); } // namespace tflite } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_WHITELISTED_FLEX_OPS_H_ +#endif // TENSORFLOW_LITE_TOCO_TFLITE_WHITELISTED_FLEX_OPS_H_ diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/lite/toco/toco.cc similarity index 90% rename from tensorflow/contrib/lite/toco/toco.cc rename to tensorflow/lite/toco/toco.cc index d1c431278581bd..9740015850a05c 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/lite/toco/toco.cc @@ -17,14 +17,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_tooling.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_cmdline_flags.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_tooling.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -60,7 +60,7 @@ void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, // Ensure savedmodel_directory is not set. QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) - << "Use `tensorflow/contrib/lite/python/tflite_convert` script with " + << "Use `tensorflow/lite/python/tflite_convert` script with " << "SavedModel directories.\n"; // Checks the input file permissions and reads the contents. diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/lite/toco/toco_cmdline_flags.cc similarity index 94% rename from tensorflow/contrib/lite/toco/toco_cmdline_flags.cc rename to tensorflow/lite/toco/toco_cmdline_flags.cc index cff79776bc787e..7d525ae5583c4f 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/lite/toco/toco_cmdline_flags.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/strip.h" #include "absl/types/optional.h" -#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -166,12 +166,13 @@ bool ParseTocoFlagsFromCommandLineFlags( "Boolean indicating whether to quantize the weights of the " "converted float model. Model size will be reduced and there will " "be latency improvements (at the cost of accuracy)."), + // TODO(b/118822804): Unify the argument definition with `tflite_convert`. // WARNING: Experimental interface, subject to change - Flag("allow_flex_ops", parsed_flags.allow_flex_ops.bind(), - parsed_flags.allow_flex_ops.default_value(), ""), + Flag("enable_select_tf_ops", parsed_flags.enable_select_tf_ops.bind(), + parsed_flags.enable_select_tf_ops.default_value(), ""), // WARNING: Experimental interface, subject to change - Flag("force_flex_ops", parsed_flags.force_flex_ops.bind(), - parsed_flags.force_flex_ops.default_value(), "")}; + Flag("force_select_tf_ops", parsed_flags.force_select_tf_ops.bind(), + parsed_flags.force_select_tf_ops.default_value(), "")}; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); if (asked_for_help) { @@ -266,15 +267,15 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); - READ_TOCO_FLAG(allow_flex_ops, FlagRequirement::kNone); - READ_TOCO_FLAG(force_flex_ops, FlagRequirement::kNone); + READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone); + READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone); - if (parsed_toco_flags.force_flex_ops.value() && - !parsed_toco_flags.allow_flex_ops.value()) { - // TODO(ycling): Consider to enforce `allow_flex_ops` when - // `force_flex_ops` is true. - LOG(WARNING) << "--force_flex_ops should always be used with " - "--allow_flex_ops."; + if (parsed_toco_flags.force_select_tf_ops.value() && + !parsed_toco_flags.enable_select_tf_ops.value()) { + // TODO(ycling): Consider to enforce `enable_select_tf_ops` when + // `force_select_tf_ops` is true. + LOG(WARNING) << "--force_select_tf_ops should always be used with " + "--enable_select_tf_ops."; } // Deprecated flag handling. diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/lite/toco/toco_cmdline_flags.h similarity index 79% rename from tensorflow/contrib/lite/toco/toco_cmdline_flags.h rename to tensorflow/lite/toco/toco_cmdline_flags.h index 46eb3f57283cc5..cf57055abc26e6 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h +++ b/tensorflow/lite/toco/toco_cmdline_flags.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#ifndef TENSORFLOW_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#define TENSORFLOW_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ #include #include -#include "tensorflow/contrib/lite/toco/args.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/lite/toco/args.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/types.pb.h" namespace toco { // Parse and remove arguments handled from toco. Returns true if parsing @@ -33,4 +33,4 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#endif // TENSORFLOW_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto similarity index 93% rename from tensorflow/contrib/lite/toco/toco_flags.proto rename to tensorflow/lite/toco/toco_flags.proto index ca3e64485e7a46..cb015ba3d2a742 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; -import "tensorflow/contrib/lite/toco/types.proto"; +import "tensorflow/lite/toco/types.proto"; package toco; @@ -190,16 +190,19 @@ message TocoFlags { // (at the cost of accuracy). optional bool post_training_quantize = 26 [default = false]; - // When enabled, unsupported ops will be converted to TFLite Flex ops. + // This flag only works when converting to TensorFlow Lite format. + // When enabled, unsupported ops will be converted to select TensorFlow ops. // TODO(ycling): Consider to rename the following 2 flags and don't call it // "Flex". - // `allow_flex_ops` should always be used with `allow_custom_ops`. + // `enable_select_tf_ops` should always be used with `allow_custom_ops`. // WARNING: Experimental interface, subject to change - optional bool allow_flex_ops = 27 [default = false]; + optional bool enable_select_tf_ops = 27 [default = false]; - // When enabled, all TensorFlow ops will be converted to TFLite Flex - // ops directly. This will force `allow_flex_ops` to true. - // `force_flex_ops` should always be used with `allow_flex_ops`. + // This flag only works when converting to TensorFlow Lite format. + // When enabled, all TensorFlow ops will be converted to select TensorFlow + // ops. + // This will force `enable_select_tf_ops` to true. + // `force_select_tf_ops` should always be used with `enable_select_tf_ops`. // WARNING: Experimental interface, subject to change - optional bool force_flex_ops = 28 [default = false]; + optional bool force_select_tf_ops = 28 [default = false]; } diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc b/tensorflow/lite/toco/toco_graphviz_dump_options.cc similarity index 92% rename from tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc rename to tensorflow/lite/toco/toco_graphviz_dump_options.cc index 4e98e7081de438..449f0f07cec128 100644 --- a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc +++ b/tensorflow/lite/toco/toco_graphviz_dump_options.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/lite/toco/toco_graphviz_dump_options.h" namespace toco { GraphVizDumpOptions* GraphVizDumpOptions::singleton() { diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/lite/toco/toco_graphviz_dump_options.h similarity index 82% rename from tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h rename to tensorflow/lite/toco/toco_graphviz_dump_options.h index 7cdd55e5422589..00d9cd13a66272 100644 --- a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h +++ b/tensorflow/lite/toco/toco_graphviz_dump_options.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#ifndef TENSORFLOW_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#define TENSORFLOW_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ #include @@ -29,4 +29,4 @@ struct GraphVizDumpOptions { } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#endif // TENSORFLOW_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/lite/toco/toco_port.cc similarity index 98% rename from tensorflow/contrib/lite/toco/toco_port.cc rename to tensorflow/lite/toco/toco_port.cc index 204c0d101eac6d..0881065a23f122 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/lite/toco/toco_port.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/lite/toco/toco_port.h similarity index 94% rename from tensorflow/contrib/lite/toco/toco_port.h rename to tensorflow/lite/toco/toco_port.h index 17f82b9dd7dcc6..2f39e3d6d5c024 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/lite/toco/toco_port.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#ifndef TENSORFLOW_LITE_TOCO_TOCO_PORT_H_ +#define TENSORFLOW_LITE_TOCO_TOCO_PORT_H_ // Portability layer for toco tool. Mainly, abstract filesystem access so we // can build and use on google internal environments and on OSX. #include #include "google/protobuf/text_format.h" -#include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/lite/toco/format_port.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" @@ -110,4 +110,4 @@ bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#endif // TENSORFLOW_LITE_TOCO_TOCO_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port_test.cc b/tensorflow/lite/toco/toco_port_test.cc similarity index 88% rename from tensorflow/contrib/lite/toco/toco_port_test.cc rename to tensorflow/lite/toco/toco_port_test.cc index 650a617aebc053..f5fbb4caeb2882 100644 --- a/tensorflow/contrib/lite/toco/toco_port_test.cc +++ b/tensorflow/lite/toco/toco_port_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/toco_port.h" +#include "tensorflow/lite/toco/toco_types.h" #include #include @@ -23,9 +23,9 @@ namespace port { namespace { #ifdef PLATFORM_GOOGLE -#define TFLITE_PREFIX "third_party/tensorflow/contrib/lite/" +#define TFLITE_PREFIX "third_party/tensorflow/lite/" #else -#define TFLITE_PREFIX "tensorflow/contrib/lite/" +#define TFLITE_PREFIX "tensorflow/lite/" #endif TEST(TocoPortTest, Exists) { diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc similarity index 95% rename from tensorflow/contrib/lite/toco/toco_tooling.cc rename to tensorflow/lite/toco/toco_tooling.cc index c4eacf836e3d5d..5f96e833fbf400 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/toco_tooling.h" +#include "tensorflow/lite/toco/toco_tooling.h" #include #include @@ -20,16 +20,16 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_join.h" -#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" -#include "tensorflow/contrib/lite/toco/dump_graphviz.h" -#include "tensorflow/contrib/lite/toco/export_tensorflow.h" -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/import_tensorflow.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tflite/export.h" -#include "tensorflow/contrib/lite/toco/tflite/import.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/allocate_transient_arrays.h" +#include "tensorflow/lite/toco/dump_graphviz.h" +#include "tensorflow/lite/toco/export_tensorflow.h" +#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/lite/toco/import_tensorflow.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tflite/export.h" +#include "tensorflow/lite/toco/tflite/import.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -198,7 +198,7 @@ std::unique_ptr Import(const TocoFlags& toco_flags, : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF); tf_import_flags.import_all_ops_as_unsupported = - toco_flags.force_flex_ops(); + toco_flags.force_select_tf_ops(); model = ImportTensorFlowGraphDef(model_flags, tf_import_flags, input_file_contents); @@ -409,8 +409,8 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, case TFLITE: { toco::tflite::ExportParams params; - params.allow_flex_ops = - toco_flags.force_flex_ops() || toco_flags.allow_flex_ops(); + params.enable_select_tf_ops = + toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops(); params.allow_custom_ops = allow_custom_ops; params.quantize_weights = toco_flags.post_training_quantize(); diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/lite/toco/toco_tooling.h similarity index 84% rename from tensorflow/contrib/lite/toco/toco_tooling.h rename to tensorflow/lite/toco/toco_tooling.h index 40c0e7f0a39089..742e3769269859 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.h +++ b/tensorflow/lite/toco/toco_tooling.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#ifndef TENSORFLOW_LITE_TOCO_TOCO_TOOLING_H_ +#define TENSORFLOW_LITE_TOCO_TOCO_TOOLING_H_ #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" namespace toco { @@ -50,4 +50,4 @@ inline void Export(const TocoFlags& toco_flags, const Model& model, } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#endif // TENSORFLOW_LITE_TOCO_TOCO_TOOLING_H_ diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/lite/toco/toco_types.h similarity index 88% rename from tensorflow/contrib/lite/toco/toco_types.h rename to tensorflow/lite/toco/toco_types.h index 319f1066cdb33e..da2efd6724a704 100644 --- a/tensorflow/contrib/lite/toco/toco_types.h +++ b/tensorflow/lite/toco/toco_types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TYPES_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TYPES_H_ +#ifndef TENSORFLOW_LITE_TOCO_TOCO_TYPES_H_ +#define TENSORFLOW_LITE_TOCO_TOCO_TYPES_H_ #include #include "tensorflow/core/platform/platform.h" @@ -42,4 +42,4 @@ using tensorflow::uint8; } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TYPES_H_ +#endif // TENSORFLOW_LITE_TOCO_TOCO_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc similarity index 99% rename from tensorflow/contrib/lite/toco/tooling_util.cc rename to tensorflow/lite/toco/tooling_util.cc index 2d6968239efecf..e33f7c8452f88d 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/tooling_util.h" #include #include @@ -27,9 +27,9 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" #include "re2/re2.h" -#include "tensorflow/contrib/lite/toco/dump_graphviz.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/lite/toco/dump_graphviz.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/lite/toco/tooling_util.h similarity index 96% rename from tensorflow/contrib/lite/toco/tooling_util.h rename to tensorflow/lite/toco/tooling_util.h index 5f4b8cb66a2c54..92ce82632f9685 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/lite/toco/tooling_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#ifndef TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_ +#define TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_ #include #include @@ -28,12 +28,12 @@ limitations under the License. #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS -#include "tensorflow/contrib/lite/kernels/internal/types.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/runtime/types.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/lite/toco/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -358,4 +358,4 @@ void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst); } // namespace toco -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#endif // TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/lite/toco/tooling_util_test.cc similarity index 98% rename from tensorflow/contrib/lite/toco/tooling_util_test.cc rename to tensorflow/lite/toco/tooling_util_test.cc index eb495646a2df0d..e3826cb8fde69f 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/lite/toco/tooling_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/lib/core/status.h" namespace toco { diff --git a/tensorflow/contrib/lite/toco/types.proto b/tensorflow/lite/toco/types.proto similarity index 100% rename from tensorflow/contrib/lite/toco/types.proto rename to tensorflow/lite/toco/types.proto diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/lite/tools/BUILD similarity index 60% rename from tensorflow/contrib/lite/tools/BUILD rename to tensorflow/lite/tools/BUILD index 0b268264031f4f..93725b5de473e4 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -4,7 +4,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") common_copts = ["-Wall"] @@ -13,7 +13,7 @@ py_binary( name = "visualize", srcs = ["visualize.py"], data = [ - "//tensorflow/contrib/lite/schema:schema.fbs", + "//tensorflow/lite/schema:schema.fbs", "//tensorflow/python:platform", "@flatbuffers//:flatc", ], @@ -24,9 +24,9 @@ tf_cc_binary( name = "generate_op_registrations", srcs = ["gen_op_registration_main.cc"], deps = [ - "//tensorflow/contrib/lite/tools:gen_op_registration", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/lite/tools:gen_op_registration", "@com_google_absl//absl/strings", ], ) @@ -36,8 +36,8 @@ cc_library( srcs = ["gen_op_registration.cc"], hdrs = ["gen_op_registration.h"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string", + "//tensorflow/lite:framework", + "//tensorflow/lite:string", "@com_googlesource_code_re2//:re2", ], ) @@ -46,11 +46,11 @@ cc_test( name = "gen_op_registration_test", srcs = ["gen_op_registration_test.cc"], data = [ - "//tensorflow/contrib/lite:testdata/0_subgraphs.bin", - "//tensorflow/contrib/lite:testdata/2_subgraphs.bin", - "//tensorflow/contrib/lite:testdata/empty_model.bin", - "//tensorflow/contrib/lite:testdata/test_model.bin", - "//tensorflow/contrib/lite:testdata/test_model_broken.bin", + "//tensorflow/lite:testdata/0_subgraphs.bin", + "//tensorflow/lite:testdata/2_subgraphs.bin", + "//tensorflow/lite:testdata/empty_model.bin", + "//tensorflow/lite:testdata/test_model.bin", + "//tensorflow/lite:testdata/test_model_broken.bin", ], tags = [ "no_oss", @@ -68,10 +68,10 @@ cc_library( srcs = ["verifier.cc"], hdrs = ["verifier.h"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite:string_util", + "//tensorflow/lite/schema:schema_fbs", ], ) @@ -85,11 +85,11 @@ cc_test( ], deps = [ ":verifier", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:framework_lite", + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/lite/tools/accuracy/BUILD similarity index 92% rename from tensorflow/contrib/lite/tools/accuracy/BUILD rename to tensorflow/lite/tools/accuracy/BUILD index 1b60d6a60d39cc..64475e057ae415 100644 --- a/tensorflow/contrib/lite/tools/accuracy/BUILD +++ b/tensorflow/lite/tools/accuracy/BUILD @@ -5,8 +5,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") common_linkopts = tflite_linkopts() + select({ "//conditions:default": [], @@ -22,8 +22,8 @@ cc_library( hdrs = ["utils.h"], copts = tflite_copts(), deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", ] + select( { "//tensorflow:android": [ @@ -40,9 +40,9 @@ tf_cc_test( name = "utils_test", srcs = ["utils_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + "--test_model_file=$(location //tensorflow/lite:testdata/multi_add.bin)", ], - data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + data = ["//tensorflow/lite:testdata/multi_add.bin"], linkopts = common_linkopts, linkstatic = 1, tags = [ @@ -72,8 +72,8 @@ cc_library( copts = tflite_copts(), deps = [ ":utils", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", ] + select( { "//tensorflow:android": [ @@ -102,9 +102,9 @@ tf_cc_test( name = "run_tflite_model_op_test", srcs = ["run_tflite_model_op_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + "--test_model_file=$(location //tensorflow/lite:testdata/multi_add.bin)", ], - data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + data = ["//tensorflow/lite:testdata/multi_add.bin"], linkopts = common_linkopts, linkstatic = 1, tags = [ diff --git a/tensorflow/contrib/lite/tools/accuracy/README.md b/tensorflow/lite/tools/accuracy/README.md similarity index 100% rename from tensorflow/contrib/lite/tools/accuracy/README.md rename to tensorflow/lite/tools/accuracy/README.md diff --git a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h b/tensorflow/lite/tools/accuracy/accuracy_eval_stage.h similarity index 88% rename from tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h rename to tensorflow/lite/tools/accuracy/accuracy_eval_stage.h index 9cb843729aa8c1..5a2ba3d2a7a2f1 100644 --- a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h +++ b/tensorflow/lite/tools/accuracy/accuracy_eval_stage.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ #include @@ -46,4 +46,4 @@ class AccuracyEval { }; } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc b/tensorflow/lite/tools/accuracy/android_required_build_flags.cc similarity index 100% rename from tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc rename to tensorflow/lite/tools/accuracy/android_required_build_flags.cc diff --git a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h b/tensorflow/lite/tools/accuracy/csv_writer.h similarity index 92% rename from tensorflow/contrib/lite/tools/accuracy/csv_writer.h rename to tensorflow/lite/tools/accuracy/csv_writer.h index 806b0d9418e8b0..d74a803ce18766 100644 --- a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h +++ b/tensorflow/lite/tools/accuracy/csv_writer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ #include #include @@ -76,4 +76,4 @@ class CSVWriter { }; } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc b/tensorflow/lite/tools/accuracy/eval_pipeline.cc similarity index 95% rename from tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc rename to tensorflow/lite/tools/accuracy/eval_pipeline.cc index a03aba6a2685db..658824a7d03fe6 100644 --- a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc +++ b/tensorflow/lite/tools/accuracy/eval_pipeline.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline.h" namespace tensorflow { namespace metrics { diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h b/tensorflow/lite/tools/accuracy/eval_pipeline.h similarity index 89% rename from tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h rename to tensorflow/lite/tools/accuracy/eval_pipeline.h index c9cfc866139da8..1ec21b07e8bce1 100644 --- a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h +++ b/tensorflow/lite/tools/accuracy/eval_pipeline.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ #include -#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" -#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/lite/tools/accuracy/stage.h" #include "tensorflow/core/public/session.h" namespace tensorflow { @@ -84,4 +84,4 @@ class EvalPipeline { }; } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc b/tensorflow/lite/tools/accuracy/eval_pipeline_builder.cc similarity index 97% rename from tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc rename to tensorflow/lite/tools/accuracy/eval_pipeline_builder.cc index 2e16437e1588b4..1b360d31b36e57 100644 --- a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc +++ b/tensorflow/lite/tools/accuracy/eval_pipeline_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline_builder.h" #include "absl/memory/memory.h" #include "tensorflow/cc/ops/standard_ops.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h b/tensorflow/lite/tools/accuracy/eval_pipeline_builder.h similarity index 89% rename from tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h rename to tensorflow/lite/tools/accuracy/eval_pipeline_builder.h index 692db022f8bc74..18b52ac7bea361 100644 --- a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h +++ b/tensorflow/lite/tools/accuracy/eval_pipeline_builder.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ #include #include -#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" -#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/lite/tools/accuracy/stage.h" namespace tensorflow { namespace metrics { @@ -96,4 +96,4 @@ class EvalPipelineBuilder { } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc b/tensorflow/lite/tools/accuracy/eval_pipeline_builder_test.cc similarity index 99% rename from tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc rename to tensorflow/lite/tools/accuracy/eval_pipeline_builder_test.cc index 2d41929b7920f4..9bf725439c486d 100644 --- a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc +++ b/tensorflow/lite/tools/accuracy/eval_pipeline_builder_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline_builder.h" #include #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/public/session.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc b/tensorflow/lite/tools/accuracy/eval_pipeline_test.cc similarity index 98% rename from tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc rename to tensorflow/lite/tools/accuracy/eval_pipeline_test.cc index ea0f6e19df46d8..53cbf8ccd5b7fd 100644 --- a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc +++ b/tensorflow/lite/tools/accuracy/eval_pipeline_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline.h" #include #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/public/session.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc b/tensorflow/lite/tools/accuracy/file_reader_stage.cc similarity index 93% rename from tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc rename to tensorflow/lite/tools/accuracy/file_reader_stage.cc index 61bed369f8b4f6..a106a79a4baedc 100644 --- a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc +++ b/tensorflow/lite/tools/accuracy/file_reader_stage.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/lite/tools/accuracy/file_reader_stage.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h b/tensorflow/lite/tools/accuracy/file_reader_stage.h similarity index 81% rename from tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h rename to tensorflow/lite/tools/accuracy/file_reader_stage.h index 18db5837c1717c..19655e96973498 100644 --- a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h +++ b/tensorflow/lite/tools/accuracy/file_reader_stage.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ #include -#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/lite/tools/accuracy/stage.h" namespace tensorflow { namespace metrics { @@ -34,4 +34,4 @@ class FileReaderStage : public Stage { }; } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc b/tensorflow/lite/tools/accuracy/file_reader_stage_test.cc similarity index 97% rename from tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc rename to tensorflow/lite/tools/accuracy/file_reader_stage_test.cc index a75f99187d6ea0..21be0a766b5ec4 100644 --- a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc +++ b/tensorflow/lite/tools/accuracy/file_reader_stage_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/lite/tools/accuracy/file_reader_stage.h" #include "tensorflow/core/public/session.h" namespace tensorflow { diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD similarity index 82% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD rename to tensorflow/lite/tools/accuracy/ilsvrc/BUILD index 98e2835b2ebd2f..a4d21961a6f00e 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD +++ b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD @@ -5,8 +5,8 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") common_linkopts = tflite_linkopts() + select({ "//conditions:default": [], @@ -22,8 +22,8 @@ cc_library( hdrs = ["inception_preprocessing.h"], copts = tflite_copts(), deps = [ - "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", - "//tensorflow/contrib/lite/tools/accuracy:stage", + "//tensorflow/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/lite/tools/accuracy:stage", "//tensorflow/cc:cc_ops", "//tensorflow/cc:scope", ] + select( @@ -60,7 +60,7 @@ tf_cc_test( ], deps = [ ":inception_preprocessing", - "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/lite/tools/accuracy:android_required_build_flags", "@com_google_googletest//:gtest", ] + select( { @@ -83,7 +83,7 @@ cc_library( hdrs = ["imagenet_topk_eval.h"], copts = tflite_copts(), deps = [ - "//tensorflow/contrib/lite/tools/accuracy:accuracy_eval_stage", + "//tensorflow/lite/tools/accuracy:accuracy_eval_stage", ] + select( { "//tensorflow:android": [ @@ -127,12 +127,12 @@ cc_library( deps = [ ":imagenet_topk_eval", ":inception_preprocessing", - "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", - "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline", - "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline_builder", - "//tensorflow/contrib/lite/tools/accuracy:file_reader_stage", - "//tensorflow/contrib/lite/tools/accuracy:run_tflite_model_stage", - "//tensorflow/contrib/lite/tools/accuracy:utils", + "//tensorflow/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/lite/tools/accuracy:eval_pipeline", + "//tensorflow/lite/tools/accuracy:eval_pipeline_builder", + "//tensorflow/lite/tools/accuracy:file_reader_stage", + "//tensorflow/lite/tools/accuracy:run_tflite_model_stage", + "//tensorflow/lite/tools/accuracy:utils", "@com_google_absl//absl/memory", "//tensorflow/cc:cc_ops", "//tensorflow/cc:scope", @@ -164,8 +164,8 @@ tf_cc_binary( ":imagenet_model_evaluator", ":imagenet_topk_eval", "@com_google_absl//absl/memory", - "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", - "//tensorflow/contrib/lite/tools/accuracy:csv_writer", + "//tensorflow/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/lite/tools/accuracy:csv_writer", ] + select( { "//tensorflow:android": [ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/lite/tools/accuracy/ilsvrc/README.md similarity index 95% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md rename to tensorflow/lite/tools/accuracy/ilsvrc/README.md index 362ea3ac34f60a..ac3a1566e2a2c8 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md +++ b/tensorflow/lite/tools/accuracy/ilsvrc/README.md @@ -75,14 +75,14 @@ bazel build -c opt \ --cxxopt='--std=c++11' \ --copt=-D__ANDROID_TYPES_FULL__ \ --copt=-DSUPPORT_SELECTIVE_REGISTRATION \ - //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval + //tensorflow/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval ``` (2) Connect your phone. Push the binary to your phone with adb push (make the directory if required): ``` -adb push bazel-bin/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval /data/local/tmp +adb push bazel-bin/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval /data/local/tmp ``` (3) Make the binary executable. @@ -136,7 +136,7 @@ adb shell /data/local/tmp/imagenet_accuracy_eval \ bazel run -c opt \ --cxxopt='--std=c++11' \ -- \ - //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval \ + //tensorflow/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval \ --model_file=mobilenet_quant_v1_224.tflite \ --ground_truth_images_path=${IMAGENET_IMAGES_DIR} \ --ground_truth_labels=${VALIDATION_LABELS} \ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt b/tensorflow/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt similarity index 100% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt rename to tensorflow/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/lite/tools/accuracy/ilsvrc/generate_validation_labels.py similarity index 100% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py rename to tensorflow/lite/tools/accuracy/ilsvrc/generate_validation_labels.py diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc similarity index 96% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc rename to tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc index 2a8a2b9b59db06..090a023c02727c 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h" -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h" -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/lite/tools/accuracy/csv_writer.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc similarity index 95% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc rename to tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc index 63616fc3b4b066..9a74e221c13e72 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h" #include #include @@ -22,13 +22,13 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/cc/framework/scope.h" -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" -#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" -#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" -#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h" -#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/lite/tools/accuracy/eval_pipeline_builder.h" +#include "tensorflow/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" +#include "tensorflow/lite/tools/accuracy/run_tflite_model_stage.h" +#include "tensorflow/lite/tools/accuracy/utils.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h similarity index 91% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h rename to tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h index 97e4232b358cab..c3c49e9a51b525 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ #include #include -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" -#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/lite/tools/accuracy/utils.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" @@ -121,4 +121,4 @@ class ImagenetModelEvaluator { } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc similarity index 98% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc rename to tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc index c75baa82b1d013..2b086cdf7075d7 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" #include diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h similarity index 91% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h rename to tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h index cad646a30ca96b..e1fc445abf41b5 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ #include #include -#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/lite/tools/accuracy/accuracy_eval_stage.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/mutex.h" @@ -80,4 +80,4 @@ class ImagenetTopKAccuracy : public AccuracyEval { } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc similarity index 98% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc rename to tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc index ff332af5c5e56e..61b7afc552de60 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" #include namespace tensorflow { diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc similarity index 97% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc rename to tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc index 7512b39c32f98f..9a889f0dd88bc4 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" #include diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h similarity index 89% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h rename to tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h index 15df71981756f6..4a1d3ce4769d1a 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h +++ b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ #include -#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/lite/tools/accuracy/stage.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -72,4 +72,4 @@ class InceptionPreprocessingStage : public Stage { } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc similarity index 98% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc rename to tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc index 3587878ba3cadd..5d0e01d7d18c45 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" +#include "tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg b/tensorflow/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg similarity index 100% rename from tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg rename to tensorflow/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc b/tensorflow/lite/tools/accuracy/run_tflite_model_op.cc similarity index 95% rename from tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc rename to tensorflow/lite/tools/accuracy/run_tflite_model_op.cc index da4258f1c13107..5f413b8ee39324 100644 --- a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc +++ b/tensorflow/lite/tools/accuracy/run_tflite_model_op.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/op_resolver.h" +#include "tensorflow/lite/tools/accuracy/utils.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc b/tensorflow/lite/tools/accuracy/run_tflite_model_op_test.cc similarity index 100% rename from tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc rename to tensorflow/lite/tools/accuracy/run_tflite_model_op_test.cc diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc b/tensorflow/lite/tools/accuracy/run_tflite_model_stage.cc similarity index 95% rename from tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc rename to tensorflow/lite/tools/accuracy/run_tflite_model_stage.cc index c96795d4994ae3..6082290c0bc4fb 100644 --- a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc +++ b/tensorflow/lite/tools/accuracy/run_tflite_model_stage.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h" +#include "tensorflow/lite/tools/accuracy/run_tflite_model_stage.h" #include diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h b/tensorflow/lite/tools/accuracy/run_tflite_model_stage.h similarity index 85% rename from tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h rename to tensorflow/lite/tools/accuracy/run_tflite_model_stage.h index 90d12d6f424516..61034491777a8d 100644 --- a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h +++ b/tensorflow/lite/tools/accuracy/run_tflite_model_stage.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ #include -#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/lite/tools/accuracy/stage.h" namespace tensorflow { namespace metrics { @@ -50,4 +50,4 @@ class RunTFLiteModelStage : public Stage { } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/stage.h b/tensorflow/lite/tools/accuracy/stage.h similarity index 90% rename from tensorflow/contrib/lite/tools/accuracy/stage.h rename to tensorflow/lite/tools/accuracy/stage.h index 8292ea2ec735dc..0a9e3fbd055e67 100644 --- a/tensorflow/contrib/lite/tools/accuracy/stage.h +++ b/tensorflow/lite/tools/accuracy/stage.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_STAGE_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_STAGE_H_ #include "tensorflow/cc/framework/scope.h" @@ -53,4 +53,4 @@ class Stage { } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.cc b/tensorflow/lite/tools/accuracy/utils.cc similarity index 92% rename from tensorflow/contrib/lite/tools/accuracy/utils.cc rename to tensorflow/lite/tools/accuracy/utils.cc index f5493301fc4d78..c19dc1ff7cca10 100644 --- a/tensorflow/contrib/lite/tools/accuracy/utils.cc +++ b/tensorflow/lite/tools/accuracy/utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/lite/tools/accuracy/utils.h" #include @@ -22,10 +22,10 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/op_resolver.h" namespace tensorflow { namespace metrics { diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.h b/tensorflow/lite/tools/accuracy/utils.h similarity index 85% rename from tensorflow/contrib/lite/tools/accuracy/utils.h rename to tensorflow/lite/tools/accuracy/utils.h index 37cbad4d51fd0d..5b7639317eff0b 100644 --- a/tensorflow/contrib/lite/tools/accuracy/utils.h +++ b/tensorflow/lite/tools/accuracy/utils.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_ +#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_UTILS_H_ +#define TENSORFLOW_LITE_TOOLS_ACCURACY_UTILS_H_ #include #include -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/lite/context.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -43,4 +43,4 @@ Status ReadFileLines(const string& file_path, } // namespace utils } // namespace metrics } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_ +#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_UTILS_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/utils_test.cc b/tensorflow/lite/tools/accuracy/utils_test.cc similarity index 97% rename from tensorflow/contrib/lite/tools/accuracy/utils_test.cc rename to tensorflow/lite/tools/accuracy/utils_test.cc index 727eba21b6c600..401872f18ffb3a 100644 --- a/tensorflow/contrib/lite/tools/accuracy/utils_test.cc +++ b/tensorflow/lite/tools/accuracy/utils_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/lite/tools/accuracy/utils.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD similarity index 71% rename from tensorflow/contrib/lite/tools/benchmark/BUILD rename to tensorflow/lite/tools/benchmark/BUILD index f990493dc5ab50..583046ad73d67b 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -4,9 +4,9 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 -load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_linkopts") common_copts = ["-Wall"] + tflite_copts() @@ -51,8 +51,8 @@ cc_binary( deps = [ ":benchmark_tflite_model_lib", ":logging", - "//tensorflow/contrib/lite/delegates/flex:delegate", - "//tensorflow/contrib/lite/testing:init_tensorflow", + "//tensorflow/lite/delegates/flex:delegate", + "//tensorflow/lite/testing:init_tensorflow", ], ) @@ -60,9 +60,9 @@ cc_test( name = "benchmark_test", srcs = ["benchmark_test.cc"], args = [ - "--graph=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + "--graph=$(location //tensorflow/lite:testdata/multi_add.bin)", ], - data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + data = ["//tensorflow/lite:testdata/multi_add.bin"], tags = [ "tflite_not_portable_android", "tflite_not_portable_ios", @@ -70,7 +70,7 @@ cc_test( deps = [ ":benchmark_tflite_model_lib", ":command_line_flags", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -92,7 +92,7 @@ cc_test( visibility = ["//visibility:private"], deps = [ ":command_line_flags", - "//tensorflow/contrib/lite/testing:util", + "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -108,10 +108,11 @@ cc_library( deps = [ ":benchmark_model_lib", ":logging", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/profiling:profile_summarizer", + "@gemmlowp", ], ) @@ -136,13 +137,13 @@ cc_library( ":benchmark_params", ":command_line_flags", ":logging", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/profiling:profile_summarizer", - "//tensorflow/contrib/lite/profiling:profiler", - "//tensorflow/contrib/lite/profiling:time", "//tensorflow/core:stats_calculator_portable", + "//tensorflow/lite:framework", + "//tensorflow/lite:string_util", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/profiling:profile_summarizer", + "//tensorflow/lite/profiling:profiler", + "//tensorflow/lite/profiling:time", ], ) diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md similarity index 96% rename from tensorflow/contrib/lite/tools/benchmark/README.md rename to tensorflow/lite/tools/benchmark/README.md index 8d997639fb7a36..a71a2fa1c0ec3c 100644 --- a/tensorflow/contrib/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -9,7 +9,7 @@ of runs. Aggregrate latency statistics are reported after running the benchmark. The instructions below are for running the binary on Desktop and Android, for iOS please use the -[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). +[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios). ## Parameters @@ -45,14 +45,14 @@ and the following optional parameters: bazel build -c opt \ --config=android_arm \ --cxxopt='--std=c++11' \ - tensorflow/contrib/lite/tools/benchmark:benchmark_model + tensorflow/lite/tools/benchmark:benchmark_model ``` (2) Connect your phone. Push the binary to your phone with adb push (make the directory if required): ``` -adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp +adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model /data/local/tmp ``` (3) Make the binary executable. @@ -79,14 +79,14 @@ adb shell /data/local/tmp/benchmark_model \ (1) build the binary ``` -bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model +bazel build -c opt tensorflow/lite/tools/benchmark:benchmark_model ``` (2) Run on your compute graph, similar to the Android case but without the need of adb shell. For example: ``` -bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ +bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \ --graph=mobilenet_quant_v1_224.tflite \ --num_threads=4 ``` @@ -126,7 +126,7 @@ bazel build -c opt \ --config=android_arm \ --cxxopt='--std=c++11' \ --copt=-DTFLITE_PROFILING_ENABLED \ - tensorflow/contrib/lite/tools/benchmark:benchmark_model + tensorflow/lite/tools/benchmark:benchmark_model ``` This compiles TFLite with profiling enabled, now you can run the benchmark binary like before. The binary will produce detailed statistics for each operation similar to those shown below: diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc b/tensorflow/lite/tools/benchmark/benchmark_main.cc similarity index 89% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc rename to tensorflow/lite/tools/benchmark/benchmark_main.cc index 372d31e838e566..dcf82a8b7ec348 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_main.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" -#include "tensorflow/contrib/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/lite/tools/benchmark/logging.h" namespace tflite { namespace benchmark { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc similarity index 96% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc rename to tensorflow/lite/tools/benchmark/benchmark_model.cc index f86c0445b0525c..05148aea65b6e5 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" +#include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include #include #include -#include "tensorflow/contrib/lite/profiling/time.h" -#include "tensorflow/contrib/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/profiling/time.h" +#include "tensorflow/lite/tools/benchmark/logging.h" namespace { void SleepForSeconds(double sleep_seconds) { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h similarity index 93% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_model.h rename to tensorflow/lite/tools/benchmark/benchmark_model.h index cc215a7b7f08a9..d8a9b05010aba4 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_model.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ #include #include @@ -23,8 +23,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" -#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/lite/tools/benchmark/benchmark_params.h" +#include "tensorflow/lite/tools/benchmark/command_line_flags.h" #include "tensorflow/core/util/stats_calculator.h" namespace tflite { @@ -160,4 +160,4 @@ class BenchmarkModel { } // namespace benchmark } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/lite/tools/benchmark/benchmark_params.cc similarity index 92% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc rename to tensorflow/lite/tools/benchmark/benchmark_params.cc index 1dcf580a9d4995..5ab3adff553674 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_params.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" +#include "tensorflow/lite/tools/benchmark/benchmark_params.h" #include #include #include -#include "tensorflow/contrib/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/tools/benchmark/logging.h" namespace tflite { namespace benchmark { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/lite/tools/benchmark/benchmark_params.h similarity index 90% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_params.h rename to tensorflow/lite/tools/benchmark/benchmark_params.h index c98f47bb0d8986..594baa5b4ec1ec 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h +++ b/tensorflow/lite/tools/benchmark/benchmark_params.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ #include #include #include #include -#include "tensorflow/contrib/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/tools/benchmark/logging.h" namespace tflite { namespace benchmark { @@ -98,4 +98,4 @@ class BenchmarkParams { } // namespace benchmark } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_plus_flex_main.cc b/tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc similarity index 85% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_plus_flex_main.cc rename to tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc index b9cf6c67d2fe94..6e72a293770afd 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_plus_flex_main.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_plus_flex_main.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/testing/init_tensorflow.h" -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" -#include "tensorflow/contrib/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/testing/init_tensorflow.h" +#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/lite/tools/benchmark/logging.h" namespace tflite { namespace benchmark { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc similarity index 92% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc rename to tensorflow/lite/tools/benchmark/benchmark_test.cc index b697bb394db9b9..59d23d90086761 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/testing/util.h" -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" -#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/lite/tools/benchmark/command_line_flags.h" namespace { const std::string* g_model_path = nullptr; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc similarity index 93% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc rename to tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 2a3df7f289ffb7..777d9dde7dd528 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" #include #include @@ -23,11 +23,15 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/benchmark/logging.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/op_resolver.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/benchmark/logging.h" + +#ifdef GEMMLOWP_PROFILING +#include "third_party/gemmlowp/profiling/profiler.h" +#endif #ifdef TFLITE_CUSTOM_OPS_HEADER void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); @@ -62,6 +66,21 @@ void ProfilingListener::OnSingleRunEnd() { summarizer_.ProcessProfiles(profile_events, *interpreter_); } +void GemmlowpProfilingListener::OnBenchmarkStart( + const BenchmarkParams& params) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::RegisterCurrentThreadForProfiling(); + gemmlowp::StartProfiling(); +#endif +} + +void GemmlowpProfilingListener::OnBenchmarkEnd( + const BenchmarkResults& results) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::FinishProfiling(); +#endif +} + namespace { std::vector Split(const std::string& str, const char delim) { @@ -176,13 +195,12 @@ BenchmarkParams GetDefaultParams() { } // namespace BenchmarkTfLiteModel::BenchmarkTfLiteModel() - : BenchmarkModel(GetDefaultParams()) { - AddListener(&profiling_listener_); -} + : BenchmarkTfLiteModel(GetDefaultParams()) {} BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params) : BenchmarkModel(std::move(params)) { AddListener(&profiling_listener_); + AddListener(&gemmlowp_profiling_listener_); } std::vector BenchmarkTfLiteModel::GetFlags() { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h similarity index 70% rename from tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h rename to tensorflow/lite/tools/benchmark/benchmark_tflite_model.h index 25a302b2aaea40..401ab5427d3a04 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h @@ -13,21 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ #include #include #include -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" -#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/profiling/profile_summarizer.h" +#include "tensorflow/lite/tools/benchmark/benchmark_model.h" namespace tflite { namespace benchmark { -// Dumps profiling events if profiling is enabled +// Dumps profiling events if profiling is enabled. class ProfilingListener : public BenchmarkListener { public: explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {} @@ -47,11 +47,21 @@ class ProfilingListener : public BenchmarkListener { bool has_profiles_; }; +// Dumps gemmlowp profiling events if gemmlowp profiling is enabled. +class GemmlowpProfilingListener : public BenchmarkListener { + public: + virtual ~GemmlowpProfilingListener() {} + + void OnBenchmarkStart(const BenchmarkParams& params) override; + + void OnBenchmarkEnd(const BenchmarkResults& results) override; +}; + // Benchmarks a TFLite model by running tflite interpreter. class BenchmarkTfLiteModel : public BenchmarkModel { public: BenchmarkTfLiteModel(); - BenchmarkTfLiteModel(BenchmarkParams params); + explicit BenchmarkTfLiteModel(BenchmarkParams params); virtual ~BenchmarkTfLiteModel() {} std::vector GetFlags() override; @@ -74,9 +84,10 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::unique_ptr interpreter; std::vector inputs; ProfilingListener profiling_listener_; + GemmlowpProfilingListener gemmlowp_profiling_listener_; }; } // namespace benchmark } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/lite/tools/benchmark/command_line_flags.cc similarity index 96% rename from tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc rename to tensorflow/lite/tools/benchmark/command_line_flags.cc index ff818b9dcb5ee0..2fad780dc8680b 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc +++ b/tensorflow/lite/tools/benchmark/command_line_flags.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/lite/tools/benchmark/command_line_flags.h" #include #include @@ -59,11 +59,12 @@ bool ParseFlag(const std::string& flag_value, bool ParseBoolFlag(const std::string& flag_value, const std::function& hook) { - if (flag_value != "true" && flag_value != "false") { + if (flag_value != "true" && flag_value != "false" && flag_value != "0" && + flag_value != "1") { return false; } - hook(flag_value == "true"); + hook(flag_value == "true" || flag_value == "1"); return true; } } // namespace diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/lite/tools/benchmark/command_line_flags.h similarity index 95% rename from tensorflow/contrib/lite/tools/benchmark/command_line_flags.h rename to tensorflow/lite/tools/benchmark/command_line_flags.h index 6a0affd8344935..cc71450053ee8d 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h +++ b/tensorflow/lite/tools/benchmark/command_line_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ #include #include @@ -120,4 +120,4 @@ class Flags { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/lite/tools/benchmark/command_line_flags_test.cc similarity index 90% rename from tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc rename to tensorflow/lite/tools/benchmark/command_line_flags_test.cc index 03da8051099899..afdf2793bf9db6 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc +++ b/tensorflow/lite/tools/benchmark/command_line_flags_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/lite/tools/benchmark/command_line_flags.h" #include #include -#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { @@ -27,13 +27,17 @@ TEST(CommandLineFlagsTest, BasicUsage) { bool some_switch = false; std::string some_name = "something_a"; float some_float = -23.23f; + bool some_bool = false; + bool some_numeric_bool = true; const char* argv_strings[] = {"program_name", "--some_int32=20", "--some_int64=214748364700", "--some_switch=true", "--some_name=somethingelse", - "--some_float=42.0"}; - int argc = 6; + "--some_float=42.0", + "--some_bool=true", + "--some_numeric_bool=0"}; + int argc = 8; bool parsed_ok = Flags::Parse( &argc, reinterpret_cast(argv_strings), { @@ -42,6 +46,9 @@ TEST(CommandLineFlagsTest, BasicUsage) { Flag::CreateFlag("some_switch", &some_switch, "some switch"), Flag::CreateFlag("some_name", &some_name, "some name"), Flag::CreateFlag("some_float", &some_float, "some float"), + Flag::CreateFlag("some_bool", &some_bool, "some bool"), + Flag::CreateFlag("some_numeric_bool", &some_numeric_bool, + "some numeric bool"), }); EXPECT_EQ(true, parsed_ok); @@ -50,6 +57,8 @@ TEST(CommandLineFlagsTest, BasicUsage) { EXPECT_EQ(true, some_switch); EXPECT_EQ("somethingelse", some_name); EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_TRUE(some_bool); + EXPECT_FALSE(some_numeric_bool); EXPECT_EQ(argc, 1); } diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/lite/tools/benchmark/ios/README.md similarity index 84% rename from tensorflow/contrib/lite/tools/benchmark/ios/README.md rename to tensorflow/lite/tools/benchmark/ios/README.md index 46144f7bf8e142..3dc29d9b941195 100644 --- a/tensorflow/contrib/lite/tools/benchmark/ios/README.md +++ b/tensorflow/lite/tools/benchmark/ios/README.md @@ -18,15 +18,15 @@ Mobilenet_1.0_224 model ## To build/install/run - Follow instructions at -[iOS build for TFLite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md) +[iOS build for TFLite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/ios.md) to build TFLite. Running ```bash -tensorflow/contrib/lite/build_ios_universal_lib.sh +tensorflow/lite/build_ios_universal_lib.sh ``` -will also build `tensorflow/contrib/lite/gen/lib/benchmark-lib.a` . +will also build `tensorflow/lite/gen/lib/benchmark-lib.a` . - Now copy the downloaded model file to `benchmark_data` directory. @@ -37,7 +37,7 @@ and other benchmark parameters. resources that need to be copied. - Ensure that `Build Phases -> Link Binary With Library` contains the -`Accelerate framework` and `tensorflow/contrib/lite/gen/lib/benchmark-lib.a`. +`Accelerate framework` and `tensorflow/lite/gen/lib/benchmark-lib.a`. - Now try running the app. The app has a single button that runs the benchmark on the model and displays results in a text view below. diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj similarity index 92% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj index 8436c752818040..958936a6607eff 100644 --- a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj +++ b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj @@ -20,7 +20,7 @@ /* Begin PBXFileReference section */ 6FE7579920D59CE500F01636 /* benchmark_params.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = benchmark_params.json; sourceTree = ""; }; - 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/gen/lib/benchmark-lib.a"; sourceTree = ""; }; + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/lite/tools/make/gen/lib/benchmark-lib.a"; sourceTree = ""; }; 6FE7579E20D5A6A700F01636 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TFLiteBenchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -310,18 +310,18 @@ CODE_SIGN_STYLE = Automatic; "HEADER_SEARCH_PATHS[arch=*]" = ( $SRCROOT/../../../../../../../, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/eigen, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/gemmlowp, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/neon_2_sse, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/farmhash/src, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/flatbuffers/include, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/flatbuffers/include, ); INFOPLIST_FILE = TFLiteBenchmark/Info.plist; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", ); - "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/gen/lib; + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/lite/tools/make/gen/lib; PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; PRODUCT_NAME = "$(TARGET_NAME)"; TARGETED_DEVICE_FAMILY = "1,2"; @@ -336,18 +336,18 @@ CODE_SIGN_STYLE = Automatic; "HEADER_SEARCH_PATHS[arch=*]" = ( $SRCROOT/../../../../../../../, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/eigen, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/gemmlowp, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/neon_2_sse, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/farmhash/src, - $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/downloads/flatbuffers/include, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/lite/tools/make/downloads/flatbuffers/include, ); INFOPLIST_FILE = TFLiteBenchmark/Info.plist; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", ); - "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/tools/make/gen/lib; + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/lite/tools/make/gen/lib; PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; PRODUCT_NAME = "$(TARGET_NAME)"; TARGETED_DEVICE_FAMILY = "1,2"; diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm similarity index 97% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm index 356d5b0e17abc7..590c215f51546f 100644 --- a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm +++ b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm @@ -18,8 +18,8 @@ #import #import #import -#import "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" -#import "tensorflow/contrib/lite/tools/benchmark/logging.h" +#import "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" +#import "tensorflow/lite/tools/benchmark/logging.h" namespace { NSString* FilePathForResourceName(NSString* filename) { diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m b/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m similarity index 100% rename from tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m rename to tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m diff --git a/tensorflow/contrib/lite/tools/benchmark/logging.h b/tensorflow/lite/tools/benchmark/logging.h similarity index 92% rename from tensorflow/contrib/lite/tools/benchmark/logging.h rename to tensorflow/lite/tools/benchmark/logging.h index 4045d1e7311512..71dd511a080ecc 100644 --- a/tensorflow/contrib/lite/tools/benchmark/logging.h +++ b/tensorflow/lite/tools/benchmark/logging.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ +#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_ +#define TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_ // LOG and CHECK macros for benchmarks. @@ -73,4 +73,4 @@ class LoggingWrapper { #define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ +#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_ diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.cc b/tensorflow/lite/tools/gen_op_registration.cc similarity index 93% rename from tensorflow/contrib/lite/tools/gen_op_registration.cc rename to tensorflow/lite/tools/gen_op_registration.cc index d80ea59170b4ed..ca66eef46607ae 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration.cc +++ b/tensorflow/lite/tools/gen_op_registration.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "re2/re2.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/tools/gen_op_registration.h" namespace tflite { diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/lite/tools/gen_op_registration.h similarity index 82% rename from tensorflow/contrib/lite/tools/gen_op_registration.h rename to tensorflow/lite/tools/gen_op_registration.h index 5f2ac6ca97fde9..a616720c934b9e 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration.h +++ b/tensorflow/lite/tools/gen_op_registration.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#ifndef TENSORFLOW_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#define TENSORFLOW_LITE_TOOLS_GEN_OP_REGISTRATION_H_ -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string.h" namespace tflite { @@ -36,4 +36,4 @@ void ReadOpsFromModel(const ::tflite::Model* model, } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#endif // TENSORFLOW_LITE_TOOLS_GEN_OP_REGISTRATION_H_ diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/lite/tools/gen_op_registration_main.cc similarity index 98% rename from tensorflow/contrib/lite/tools/gen_op_registration_main.cc rename to tensorflow/lite/tools/gen_op_registration_main.cc index f7df80821fc383..090b709478d7e7 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc +++ b/tensorflow/lite/tools/gen_op_registration_main.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "absl/strings/strip.h" -#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include "tensorflow/lite/tools/gen_op_registration.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_test.cc b/tensorflow/lite/tools/gen_op_registration_test.cc similarity index 88% rename from tensorflow/contrib/lite/tools/gen_op_registration_test.cc rename to tensorflow/lite/tools/gen_op_registration_test.cc index 28a98d68ab23a5..0ae91018ddf3db 100644 --- a/tensorflow/contrib/lite/tools/gen_op_registration_test.cc +++ b/tensorflow/lite/tools/gen_op_registration_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include "tensorflow/lite/tools/gen_op_registration.h" #include #include @@ -43,25 +43,25 @@ TEST_F(GenOpRegistrationTest, TestNonExistantFiles) { } TEST_F(GenOpRegistrationTest, TestModels) { - ReadOps("tensorflow/contrib/lite/testdata/test_model.bin"); + ReadOps("tensorflow/lite/testdata/test_model.bin"); EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"})); EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"})); } TEST_F(GenOpRegistrationTest, TestEmptyModels) { - ReadOps("tensorflow/contrib/lite/testdata/empty_model.bin"); + ReadOps("tensorflow/lite/testdata/empty_model.bin"); EXPECT_EQ(builtin_ops_.size(), 0); EXPECT_EQ(custom_ops_.size(), 0); } TEST_F(GenOpRegistrationTest, TestZeroSubgraphs) { - ReadOps("tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + ReadOps("tensorflow/lite/testdata/0_subgraphs.bin"); EXPECT_EQ(builtin_ops_.size(), 0); EXPECT_EQ(custom_ops_.size(), 0); } TEST_F(GenOpRegistrationTest, TestBrokenMmap) { - ReadOps("tensorflow/contrib/lite/testdata/test_model_broken.bin"); + ReadOps("tensorflow/lite/testdata/test_model_broken.bin"); EXPECT_EQ(builtin_ops_.size(), 0); EXPECT_EQ(custom_ops_.size(), 0); } diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile similarity index 77% rename from tensorflow/contrib/lite/tools/make/Makefile rename to tensorflow/lite/tools/make/Makefile index 16012a3fb16398..8f123558545723 100644 --- a/tensorflow/contrib/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -20,7 +20,7 @@ endif HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) # Override these on the make command line to target a specific architecture. For example: -# make -f tensorflow/contrib/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l +# make -f tensorflow/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l TARGET := $(HOST_OS) TARGET_ARCH := $(HOST_ARCH) @@ -70,55 +70,55 @@ BENCHMARK_BINARY_NAME := benchmark_model # A small example program that shows how to link against the library. MINIMAL_SRCS := \ -tensorflow/contrib/lite/examples/minimal/minimal.cc +tensorflow/lite/examples/minimal/minimal.cc # What sources we want to compile, must be kept in sync with the main Bazel # build files. PROFILER_SRCS := \ - tensorflow/contrib/lite/profiling/time.cc + tensorflow/lite/profiling/time.cc PROFILE_SUMMARIZER_SRCS := \ - tensorflow/contrib/lite/profiling/profile_summarizer.cc \ + tensorflow/lite/profiling/profile_summarizer.cc \ tensorflow/core/util/stats_calculator.cc CORE_CC_ALL_SRCS := \ -$(wildcard tensorflow/contrib/lite/*.cc) \ -$(wildcard tensorflow/contrib/lite/*.c) \ -$(wildcard tensorflow/contrib/lite/c/*.c) \ -$(wildcard tensorflow/contrib/lite/core/api/*.cc) +$(wildcard tensorflow/lite/*.cc) \ +$(wildcard tensorflow/lite/*.c) \ +$(wildcard tensorflow/lite/c/*.c) \ +$(wildcard tensorflow/lite/core/api/*.cc) ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ -$(wildcard tensorflow/contrib/lite/kernels/*.cc) \ -$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ -$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ -$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ +$(wildcard tensorflow/lite/kernels/*.cc) \ +$(wildcard tensorflow/lite/kernels/internal/*.cc) \ +$(wildcard tensorflow/lite/kernels/internal/optimized/*.cc) \ +$(wildcard tensorflow/lite/kernels/internal/reference/*.cc) \ $(PROFILER_SRCS) \ -$(wildcard tensorflow/contrib/lite/kernels/*.c) \ -$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ -$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ -$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ -$(wildcard tensorflow/contrib/lite/tools/make/downloads/farmhash/src/farmhash.cc) \ -$(wildcard tensorflow/contrib/lite/tools/make/downloads/fft2d/fftsg.c) +$(wildcard tensorflow/lite/kernels/*.c) \ +$(wildcard tensorflow/lite/kernels/internal/*.c) \ +$(wildcard tensorflow/lite/kernels/internal/optimized/*.c) \ +$(wildcard tensorflow/lite/kernels/internal/reference/*.c) \ +$(wildcard tensorflow/lite/tools/make/downloads/farmhash/src/farmhash.cc) \ +$(wildcard tensorflow/lite/tools/make/downloads/fft2d/fftsg.c) endif # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ -$(wildcard tensorflow/contrib/lite/*test.cc) \ -$(wildcard tensorflow/contrib/lite/*/*test.cc) \ -$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ -$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ -$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ +$(wildcard tensorflow/lite/*test.cc) \ +$(wildcard tensorflow/lite/*/*test.cc) \ +$(wildcard tensorflow/lite/*/*/*test.cc) \ +$(wildcard tensorflow/lite/*/*/*/*test.cc) \ +$(wildcard tensorflow/lite/kernels/test_util.cc) \ $(MINIMAL_SRCS) ifeq ($(BUILD_TYPE),micro) CORE_CC_EXCLUDE_SRCS += \ -tensorflow/contrib/lite/mmap_allocation.cc \ -tensorflow/contrib/lite/nnapi_delegate.cc +tensorflow/lite/mmap_allocation.cc \ +tensorflow/lite/nnapi_delegate.cc endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # Benchmark sources -BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark +BENCHMARK_SRCS_DIR := tensorflow/lite/tools/benchmark BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \ $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \ $(PROFILE_SUMMARIZER_SRCS) @@ -180,11 +180,11 @@ all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY) micro: $(LIB_PATH) # Hack for generating schema file bypassing flatbuffer parsing -tensorflow/contrib/lite/schema/schema_generated.h: - @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h +tensorflow/lite/schema/schema_generated.h: + @cp -u tensorflow/lite/schema/schema_generated.h.OPENSOURCE tensorflow/lite/schema/schema_generated.h # Gathers together all the objects we've compiled into a single '.a' archive. -$(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS) +$(LIB_PATH): tensorflow/lite/schema/schema_generated.h $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) diff --git a/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh b/tensorflow/lite/tools/make/build_ios_universal_lib.sh similarity index 69% rename from tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh rename to tensorflow/lite/tools/make/build_ios_universal_lib.sh index fe056945a652b0..477883b49095b2 100755 --- a/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh +++ b/tensorflow/lite/tools/make/build_ios_universal_lib.sh @@ -23,17 +23,17 @@ cd "$SCRIPT_DIR/../../../../.." make_library() { for arch in x86_64 armv7 armv7s arm64 do - make -f tensorflow/contrib/lite/tools/make/Makefile TARGET=ios TARGET_ARCH=${arch} \ + make -f tensorflow/lite/tools/make/Makefile TARGET=ios TARGET_ARCH=${arch} \ -j 8 done - mkdir -p tensorflow/contrib/lite/tools/make/gen/lib + mkdir -p tensorflow/lite/tools/make/gen/lib lipo \ - tensorflow/contrib/lite/tools/make/gen/ios_x86_64/lib/${1} \ - tensorflow/contrib/lite/tools/make/gen/ios_armv7/lib/${1} \ - tensorflow/contrib/lite/tools/make/gen/ios_armv7s/lib/${1} \ - tensorflow/contrib/lite/tools/make/gen/ios_arm64/lib/${1} \ + tensorflow/lite/tools/make/gen/ios_x86_64/lib/${1} \ + tensorflow/lite/tools/make/gen/ios_armv7/lib/${1} \ + tensorflow/lite/tools/make/gen/ios_armv7s/lib/${1} \ + tensorflow/lite/tools/make/gen/ios_arm64/lib/${1} \ -create \ - -output tensorflow/contrib/lite/tools/make/gen/lib/${1} + -output tensorflow/lite/tools/make/gen/lib/${1} } make_library libtensorflow-lite.a diff --git a/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh b/tensorflow/lite/tools/make/build_rpi_lib.sh similarity index 87% rename from tensorflow/contrib/lite/tools/make/build_rpi_lib.sh rename to tensorflow/lite/tools/make/build_rpi_lib.sh index 24ecd4356df12c..d4047bb0eb5071 100755 --- a/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh +++ b/tensorflow/lite/tools/make/build_rpi_lib.sh @@ -19,4 +19,4 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../../../.." -CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l +CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l diff --git a/tensorflow/contrib/lite/tools/make/download_dependencies.sh b/tensorflow/lite/tools/make/download_dependencies.sh similarity index 98% rename from tensorflow/contrib/lite/tools/make/download_dependencies.sh rename to tensorflow/lite/tools/make/download_dependencies.sh index 3570f9a38d3fdc..aa5495329b1057 100755 --- a/tensorflow/contrib/lite/tools/make/download_dependencies.sh +++ b/tensorflow/lite/tools/make/download_dependencies.sh @@ -19,7 +19,7 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../../../.." -DOWNLOADS_DIR=tensorflow/contrib/lite/tools/make/downloads +DOWNLOADS_DIR=tensorflow/lite/tools/make/downloads BZL_FILE_PATH=tensorflow/workspace.bzl # Ensure it is being run from repo root diff --git a/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc b/tensorflow/lite/tools/make/targets/ios_makefile.inc similarity index 100% rename from tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc rename to tensorflow/lite/tools/make/targets/ios_makefile.inc diff --git a/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc b/tensorflow/lite/tools/make/targets/linux_makefile.inc similarity index 100% rename from tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc rename to tensorflow/lite/tools/make/targets/linux_makefile.inc diff --git a/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc b/tensorflow/lite/tools/make/targets/riscv_makefile.inc similarity index 100% rename from tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc rename to tensorflow/lite/tools/make/targets/riscv_makefile.inc diff --git a/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc b/tensorflow/lite/tools/make/targets/rpi_makefile.inc similarity index 100% rename from tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc rename to tensorflow/lite/tools/make/targets/rpi_makefile.inc diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc b/tensorflow/lite/tools/make/targets/stm32f1_makefile.inc similarity index 100% rename from tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc rename to tensorflow/lite/tools/make/targets/stm32f1_makefile.inc diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc b/tensorflow/lite/tools/make/targets/stm32f7_makefile.inc similarity index 100% rename from tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc rename to tensorflow/lite/tools/make/targets/stm32f7_makefile.inc diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD similarity index 67% rename from tensorflow/contrib/lite/tools/optimize/BUILD rename to tensorflow/lite/tools/optimize/BUILD index 51ccaedc23d0ab..0a0d5cc4123ba6 100644 --- a/tensorflow/contrib/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -8,17 +8,17 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/lite:build_def.bzl", "tflite_copts") cc_library( name = "quantize_weights", srcs = ["quantize_weights.cc"], hdrs = ["quantize_weights.h"], deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels/internal:tensor_utils", - "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels/internal:tensor_utils", + "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/memory", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md b/tensorflow/lite/tools/optimize/g3doc/quantize_weights.md similarity index 96% rename from tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md rename to tensorflow/lite/tools/optimize/g3doc/quantize_weights.md index 93fe576583eaaf..2517882c84c307 100644 --- a/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md +++ b/tensorflow/lite/tools/optimize/g3doc/quantize_weights.md @@ -3,7 +3,7 @@ ## Recommended usage The Quantize Weights transformation is integrated with -[tflite_convert](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md#transformation-flags). +[tflite_convert](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/g3doc/cmdline_reference.md#transformation-flags). The recommended way of invoking this tool is by simply adding the `--post_training_quantize` flag to your original tflite_convert invocation. For diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc similarity index 94% rename from tensorflow/contrib/lite/tools/optimize/quantize_weights.cc rename to tensorflow/lite/tools/optimize/quantize_weights.cc index d02d78bf53a145..de3c0b03237c1c 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" +#include "tensorflow/lite/tools/optimize/quantize_weights.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" #include "absl/memory/memory.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/core/platform/logging.h" namespace tflite { @@ -110,24 +110,24 @@ std::vector GetWeightInputIndices(const BuiltinOperator& op_code) { op_code == BuiltinOperator_EMBEDDING_LOOKUP) { return {1}; } else if (op_code == BuiltinOperator_SVDF) { - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/svdf.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc return {1, 2}; } else if (op_code == BuiltinOperator_LSTM || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) { - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/lstm.cc - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16}; } else if (op_code == BuiltinOperator_RNN || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/basic_rnn.cc - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc return {1, 2}; } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) { - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 33}; } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) { - // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc return {1, 2, 4, 5}; } return {}; @@ -182,8 +182,7 @@ std::vector GetQuantizableTensorsFromOperator( TensorT* tensor = subgraph->tensors[tensor_idx].get(); // TODO(suharshs): Support shared weights, i.e. If two tensors share the // same weight array, things may break. (i.e. SSD object detection) - if (!eval_hybrid && - CountTensorConsumers(model, subgraph, tensor_idx) != 1) { + if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) { LOG(INFO) << "Skipping quantization of tensor " << tensor->name << " that is shared between multiple multiple operations."; continue; diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/lite/tools/optimize/quantize_weights.h similarity index 85% rename from tensorflow/contrib/lite/tools/optimize/quantize_weights.h rename to tensorflow/lite/tools/optimize/quantize_weights.h index 706f10b87b166c..c2c0b0ce83435d 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h +++ b/tensorflow/lite/tools/optimize/quantize_weights.h @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ +#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ #include #include "flatbuffers/flexbuffers.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace optimize { @@ -54,4 +54,4 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, } // namespace optimize } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ +#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc similarity index 95% rename from tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc rename to tensorflow/lite/tools/optimize/quantize_weights_test.cc index 387b3471c2c4c5..32725e5ee29c36 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" +#include "tensorflow/lite/tools/optimize/quantize_weights.h" #include #include "flatbuffers/flexbuffers.h" #include #include -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace optimize { @@ -160,7 +160,7 @@ class QuantizeWeightsTest : public ::testing::Test { TEST_F(QuantizeWeightsTest, SimpleTestWithHybrid) { string model_path = - "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "third_party/tensorflow/lite/tools/optimize/testdata/" "mobilenet_v1_0.25_128.tflite"; std::unique_ptr input_fb = FlatBufferModel::BuildFromFile(model_path.data()); @@ -177,7 +177,7 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithHybrid) { TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) { string model_path = - "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "third_party/tensorflow/lite/tools/optimize/testdata/" "mobilenet_v1_0.25_128.tflite"; std::unique_ptr input_fb = FlatBufferModel::BuildFromFile(model_path.data()); @@ -195,7 +195,7 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) { TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) { string model_path = - "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "third_party/tensorflow/lite/tools/optimize/testdata/" "mobilenet_v1_0.25_128.tflite"; std::unique_ptr input_fb = FlatBufferModel::BuildFromFile(model_path.data()); diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc similarity index 97% rename from tensorflow/contrib/lite/tools/verifier.cc rename to tensorflow/lite/tools/verifier.cc index 8d3a7a624265ca..02d6e6b23cdd66 100644 --- a/tensorflow/contrib/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/verifier.h" +#include "tensorflow/lite/tools/verifier.h" #include -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/lite/tools/verifier.h similarity index 87% rename from tensorflow/contrib/lite/tools/verifier.h rename to tensorflow/lite/tools/verifier.h index a596c650a0c253..50b6432d4e3d82 100644 --- a/tensorflow/contrib/lite/tools/verifier.h +++ b/tensorflow/lite/tools/verifier.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_ -#define TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_ +#ifndef TENSORFLOW_LITE_TOOLS_VERIFIER_H_ +#define TENSORFLOW_LITE_TOOLS_VERIFIER_H_ #include -#include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/model.h" namespace tflite { @@ -49,4 +49,4 @@ bool Verify(const void* buf, size_t len, const OpResolver& resolver, } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_ +#endif // TENSORFLOW_LITE_TOOLS_VERIFIER_H_ diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/lite/tools/verifier_test.cc similarity index 96% rename from tensorflow/contrib/lite/tools/verifier_test.cc rename to tensorflow/lite/tools/verifier_test.cc index ad7d59ecb41a0c..98abafad927ae4 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/lite/tools/verifier_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" #include "flatbuffers/util.h" #include -#include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/testing/util.h" -#include "tensorflow/contrib/lite/tools/verifier.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/lite/allocation.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/tools/verifier.h" +#include "tensorflow/lite/version.h" #include "tensorflow/core/framework/numeric_types.h" namespace tflite { diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py similarity index 98% rename from tensorflow/contrib/lite/tools/visualize.py rename to tensorflow/lite/tools/visualize.py index bf8889aa2ab856..53bb67e3fedbda 100644 --- a/tensorflow/contrib/lite/tools/visualize.py +++ b/tensorflow/lite/tools/visualize.py @@ -31,15 +31,15 @@ from tensorflow.python.platform import resource_loader # Schema to use for flatbuffers -_SCHEMA = "third_party/tensorflow/contrib/lite/schema/schema.fbs" +_SCHEMA = "third_party/tensorflow/lite/schema/schema.fbs" # TODO(angerson): fix later when rules are simplified.. _SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs") -_BINARY = resource_loader.get_path_to_datafile("../../../../flatbuffers/flatc") +_BINARY = resource_loader.get_path_to_datafile("../../../flatbuffers/flatc") # Account for different package positioning internal vs. external. if not os.path.exists(_BINARY): _BINARY = resource_loader.get_path_to_datafile( - "../../../../../flatbuffers/flatc") + "../../../../flatbuffers/flatc") if not os.path.exists(_SCHEMA): raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA) diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/lite/tutorials/BUILD similarity index 100% rename from tensorflow/contrib/lite/tutorials/BUILD rename to tensorflow/lite/tutorials/BUILD diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/lite/tutorials/dataset.py similarity index 100% rename from tensorflow/contrib/lite/tutorials/dataset.py rename to tensorflow/lite/tutorials/dataset.py diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/lite/tutorials/mnist_tflite.py similarity index 95% rename from tensorflow/contrib/lite/tutorials/mnist_tflite.py rename to tensorflow/lite/tutorials/mnist_tflite.py index 7b8bf5b5dbc846..002365717fce9e 100644 --- a/tensorflow/contrib/lite/tutorials/mnist_tflite.py +++ b/tensorflow/lite/tutorials/mnist_tflite.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np import tensorflow as tf # pylint: disable=g-bad-import-order -from tensorflow.contrib.lite.tutorials import dataset +from tensorflow.lite.tutorials import dataset flags = tf.app.flags flags.DEFINE_string('data_dir', '/tmp/data_dir', @@ -69,7 +69,7 @@ def run_eval(interpreter, input_image): def main(_): - interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file) + interpreter = tf.lite.Interpreter(model_path=flags.model_file) interpreter.allocate_tensors() num_correct, total = 0, 0 for input_data in test_image_generator(): diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/lite/tutorials/post_training_quant.ipynb similarity index 95% rename from tensorflow/contrib/lite/tutorials/post_training_quant.ipynb rename to tensorflow/lite/tutorials/post_training_quant.ipynb index 80cdb2f080ba51..3ff145d9ce9291 100644 --- a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb +++ b/tensorflow/lite/tutorials/post_training_quant.ipynb @@ -19,10 +19,10 @@ "source": [ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", " \u003c/td\u003e\n", " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", " \u003c/td\u003e\n", "\u003c/table\u003e" ] @@ -252,7 +252,7 @@ "source": [ "import tensorflow as tf\n", "tf.enable_eager_execution()\n", - "converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)\n", + "converter = tf.lite.TocoConverter.from_saved_model(saved_model_dir)\n", "tflite_model = converter.convert()" ] }, @@ -386,7 +386,7 @@ "images, labels = tf.to_float(mnist_test[0])/255.0, mnist_test[1]\n", "\n", "# Note: If you change the batch size, then use \n", - "# `tf.contrib.lite.Interpreter.resize_tensor_input` to also change it for\n", + "# `tf.lite.Interpreter.resize_tensor_input` to also change it for\n", "# the interpreter.\n", "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)" ] @@ -411,7 +411,7 @@ }, "outputs": [], "source": [ - "interpreter = tf.contrib.lite.Interpreter(model_path=str(tflite_model_file))\n", + "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n", "interpreter.allocate_tensors()\n", "input_index = interpreter.get_input_details()[0][\"index\"]\n", "output_index = interpreter.get_output_details()[0][\"index\"]" @@ -428,7 +428,7 @@ "outputs": [], "source": [ "tf.logging.set_verbosity(tf.logging.DEBUG)\n", - "interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_model_quant_file))" + "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))" ] }, { @@ -592,7 +592,7 @@ "\n", "We now consider another example. Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n", " Pre-trained frozen graph for resnet-v2-101 is available at the\n", - " [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md).\n", + " [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md).\n", "\n", "We can convert the frozen graph to a TFLite flatbuffer with quantization by:\n" ] @@ -648,7 +648,7 @@ "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n", "input_arrays = [\"input\"] \n", "output_arrays = [\"output\"]\n", - "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(\n", + "converter = tf.lite.TocoConverter.from_frozen_graph(\n", " str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n", "converter.post_training_quantize = True\n", "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n", @@ -678,7 +678,7 @@ "source": [ "\n", "The model size reduces from 171 MB to 43 MB.\n", - "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc).\n", + "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/accuracy/ilsvrc).\n", "\n", "The optimized model top-1 accuracy is 76.8, the same as the floating point model." ] diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/lite/util.cc similarity index 97% rename from tensorflow/contrib/lite/util.cc rename to tensorflow/lite/util.cc index 6aa35b52277910..866e4ebb0aa83a 100644 --- a/tensorflow/contrib/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/util.h" #include diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/lite/util.h similarity index 92% rename from tensorflow/contrib/lite/util.h rename to tensorflow/lite/util.h index 31292a6f8131f7..64a5b52e2f982b 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/lite/util.h @@ -18,11 +18,11 @@ limitations under the License. // Flatbuffer vectors. These functions can't live in `context.h` since it's pure // C. -#ifndef TENSORFLOW_CONTRIB_LITE_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_UTIL_H_ +#ifndef TENSORFLOW_LITE_UTIL_H_ +#define TENSORFLOW_LITE_UTIL_H_ #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/lite/c/c_api_internal.h" namespace tflite { @@ -54,4 +54,4 @@ size_t CombineHashes(std::initializer_list hashes); } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_ +#endif // TENSORFLOW_LITE_UTIL_H_ diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/lite/util_test.cc similarity index 94% rename from tensorflow/contrib/lite/util_test.cc rename to tensorflow/lite/util_test.cc index 25f3aded7140ff..606d24274770d2 100644 --- a/tensorflow/contrib/lite/util_test.cc +++ b/tensorflow/lite/util_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/util.h" namespace tflite { namespace { diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/lite/version.h similarity index 87% rename from tensorflow/contrib/lite/version.h rename to tensorflow/lite/version.h index efd63f4006ae66..639d5a336a1794 100644 --- a/tensorflow/contrib/lite/version.h +++ b/tensorflow/lite/version.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_VERSION_H_ -#define TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#ifndef TENSORFLOW_LITE_VERSION_H_ +#define TENSORFLOW_LITE_VERSION_H_ // The version number of the Schema. Ideally all changes will be backward // compatible. If that ever changes, we must ensure that version is the first // entry in the new tflite root so that we can see that version is not 1. #define TFLITE_SCHEMA_VERSION (3) -#endif // TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#endif // TENSORFLOW_LITE_VERSION_H_ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index cf3cba4fda098b..0d06c49f7c7783 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -7,10 +7,11 @@ visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", - "//tensorflow/contrib/lite/toco/python:__pkg__", + "//tensorflow/lite/toco/python:__pkg__", "//tensorflow_models:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", # TODO(aselle): to pass open source test. - "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__", + "//bazel_pip/tensorflow/lite/toco/python:__pkg__", ] package(default_visibility = visibility) @@ -58,7 +59,7 @@ py_library( "//tensorflow/compiler/aot/tests:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/contrib/learn:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/contrib/learn/python/learn/datasets:__pkg__", # TODO(b/34059704): remove when fixed - "//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed + "//tensorflow/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed @@ -76,6 +77,7 @@ py_library( srcs_version = "PY2AND3", visibility = [ "//tensorflow:__pkg__", + "//tensorflow/python/estimator:__subpackages__", "//tensorflow/python/tools:__pkg__", "//tensorflow/python/tools/api/generator:__pkg__", "//tensorflow/tools/api/tests:__pkg__", @@ -136,6 +138,7 @@ py_library( ":weights_broadcast_ops", ":while_v2", "//tensorflow/core:protos_all_py", + "//tensorflow/lite/python:lite", "//tensorflow/python/compat", "//tensorflow/python/data", "//tensorflow/python/distribute:estimator_training", @@ -2036,6 +2039,8 @@ py_library( ], ) +# Note: targets depending on this should also depend on ":cond_v2" and ":while_v2". +# See b/118513001. py_library( name = "control_flow_ops", srcs = ["ops/control_flow_ops.py"], @@ -3221,6 +3226,7 @@ cuda_py_test( ":variable_scope", "//third_party/py/numpy", ], + tags = ["no_oss"], # b/118709825 ) cuda_py_test( diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py index fd9e60bea75fcd..7252e0d9bf92e4 100644 --- a/tensorflow/python/autograph/__init__.py +++ b/tensorflow/python/autograph/__init__.py @@ -26,6 +26,7 @@ from tensorflow.python.autograph import utils from tensorflow.python.autograph.core.converter import ConversionOptions from tensorflow.python.autograph.core.converter import Feature +from tensorflow.python.autograph.core.converter import Verbosity from tensorflow.python.autograph.core.errors import GraphConstructionError from tensorflow.python.autograph.core.errors import improved_errors from tensorflow.python.autograph.core.errors import TfRuntimeError @@ -58,6 +59,7 @@ 'improved_errors', 'GraphConstructionError', 'TfRuntimeError', + 'Verbosity', # Python language "extensions" 'set_element_type', 'set_loop_options', diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index 59b9ebb591865b..bc366123096ed2 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -64,6 +64,7 @@ from __future__ import print_function from enum import Enum +from enum import IntEnum from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import naming @@ -89,6 +90,17 @@ # TODO(mdan): Add a test specific to this converter. +class Verbosity(IntEnum): + """Different levels of verbosity for printing errors. + + Attributes: + * BRIEF: No logging, minimal error messages. + * VERBOSE: Detailed logging of generated code, detailed error messages. + """ + BRIEF = 0 + VERBOSE = 1 + + class Feature(Enum): """Constants to use when selecting AutoGraph features.""" @@ -111,7 +123,7 @@ class ConversionOptions(object): Attributes: recursive: bool, whether to recursively convert any user functions or classes that the converted function may use. - verbose: bool, whether to log the converted code. + verbose: Verbosity, the level of verbosity to use. strip_decorators: Tuple[Callable], contains decorators that should be in excluded from the compiled output. By default, when converting a function before the decorators are applied, the compiled output will include those @@ -126,7 +138,7 @@ class ConversionOptions(object): def __init__(self, recursive=False, - verbose=False, + verbose=Verbosity.VERBOSE, strip_decorators=None, force_conversion=False, internal_convert_user_code=True, @@ -197,7 +209,7 @@ def list_of_features(values): constructor_name=parser.parse_expression( as_qualified_name(ConversionOptions)), recursive_val=parser.parse_expression(str(self.recursive)), - verbose_val=parser.parse_expression(str(self.verbose)), + verbose_val=parser.parse_expression(str(int(self.verbose))), strip_decorators_val=list_of_names(self.strip_decorators), force_conversion_val=parser.parse_expression( str(self.force_conversion)), diff --git a/tensorflow/python/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD index bef62a640384bd..2f9037c43b6452 100644 --- a/tensorflow/python/autograph/impl/BUILD +++ b/tensorflow/python/autograph/impl/BUILD @@ -31,6 +31,7 @@ py_library( "//tensorflow/python/autograph/pyct", "//tensorflow/python/autograph/pyct/static_analysis", "//tensorflow/python/autograph/utils", + "//third_party/py/numpy", "@gast_archive//:gast", "@six_archive//:six", ], diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index dcee861f826bd3..123d289739078e 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -19,8 +19,15 @@ from __future__ import print_function import functools +import sys + from enum import Enum +# pylint:disable=g-bad-import-order +import numpy as np +# pylint:enable=g-bad-import-order + + from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.impl import conversion @@ -28,6 +35,8 @@ from tensorflow.python.autograph.pyct import compiler from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.utils import py_func +from tensorflow.python.data.util import nest +from tensorflow.python.framework import tensor_util from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -38,7 +47,9 @@ # TODO(mdan): This should behave like to_graph (e.g. convert statically). -def convert(recursive=False, verbose=False): +# TODO(znado): Make an alias so can write Verbosity directly without needing +# to write converter. +def convert(recursive=False, verbose=converter.Verbosity.VERBOSE): """Decorator that compiles a function to use TensorFlow ops. The decorator is dynamic - it recompiles the target whenever the decorated @@ -49,7 +60,7 @@ def convert(recursive=False, verbose=False): Args: recursive: bool, whether to recursively convert any functions or classes that the converted function may use. - verbose: bool, whether to output the compiled code in the logs. + verbose: converter.Verbosity, the level of verbosity. Returns: Callable, a decorator that converts the given function into an equivalent @@ -83,8 +94,7 @@ def wrapper(*args, **kwargs): class RunMode(Enum): """Specifies the way a converted function or method should be executed in TF. - The enum values have the following semantics: - + Attributes: * GRAPH: Call this function directly, as-is. This is suitable for functions that were already designed for TF graphs and contain ops. * PY_FUNC: Wrap this function into a py_func op. This is suitable for code @@ -144,7 +154,7 @@ def py_func_wrapper(*args, **kwargs): # TODO(mdan): Move to a private, undocumented module. def converted_call(f, owner, options, *args, **kwargs): """Compiles a function call inline. For internal use only.""" - if options.verbose: + if options.verbose >= converter.Verbosity.VERBOSE: logging.info('Converted call: {}; owner: {}'.format(f, owner)) if owner is not None: @@ -243,7 +253,30 @@ def converted_call(f, owner, options, *args, **kwargs): partial_types=partial_types, strip_decorators=options.strip_decorators, optional_features=options.optional_features) - return converted_f(*effective_args, **kwargs) + + result = converted_f(*effective_args, **kwargs) + # When converting a function, we write a tmp file and import it as a module. + # This leaks the module's closure. Once we've executed the converted_f module + # and there is no more code left to be executed, we can clean up the module. + + # TODO(mdan): Look into workarounds that don't suffer from refcount leaks. + # Possibly attach the closure as a regular closure cell, instead of relying on + # module globals. + + # If there are callables in the result, they will fail to find their closure + # when called, so only delete module if all returned types are not callable. + flat_results = nest.flatten(result) + if all(map(_is_not_callable, flat_results)): + del sys.modules[converted_f.__module__] + + return result + + +def _is_not_callable(obj): + # TODO(brianklee): What happens if obj is a tensor wrapping a py_func? + return (isinstance(obj, + (int, float, complex, str, bool, np.ndarray, np.generic)) + or tensor_util.is_tensor(obj)) # TODO(mdan): Rename: to_ops? @@ -251,7 +284,7 @@ def converted_call(f, owner, options, *args, **kwargs): # TODO(mdan): Remove partial_types. def to_graph(e, recursive=True, - verbose=False, + verbose=converter.Verbosity.VERBOSE, arg_values=None, arg_types=None, partial_types=None, @@ -269,7 +302,7 @@ def to_graph(e, e: Union[Callable, Type], the Python entity to convert. recursive: bool, whether to recursively convert any functions that the converted function may call. - verbose: bool, whether to output the compiled code in the logs. + verbose: converter.Verbosity, the level of printing verbosity to use. arg_values: Optional[Dict[Text, Any]], value hints for symbols including function arguments. arg_types: Optional[Dict[Text, Type]], type hints for symbols including diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index e7049c3e5e8641..276fb8748fe463 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import gc + import numpy as np from tensorflow.python.autograph import utils @@ -32,6 +34,10 @@ tf = utils.fake_tf() +class TestResource(str): + pass + + class ApiTest(test.TestCase): def test_decorator_recurses(self): @@ -360,6 +366,39 @@ def test_fn(y): self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map')) + def assertNoMemoryLeaks(self, target_f): + refs_before = set(id(obj) for obj in gc.get_objects()) + target_f() + gc.collect() + objs_after = [obj for obj in gc.get_objects() if id(obj) not in refs_before] + leaked = [obj for obj in objs_after if isinstance(obj, TestResource)] + self.assertFalse(leaked, + 'Resources {} were leaked by AutoGraph.'.format(leaked)) + + def test_no_module_memory_leak(self): + def f(): + resource = TestResource('some-resource') + @api.convert() + def target(x): + return x + resource, 42 + self.assertEqual(target('foo'), ('foosome-resource', 42)) + + self.assertNoMemoryLeaks(f) + + def test_no_module_memory_leak_deferred_call(self): + def f(): + resource = TestResource('some-resource') + @api.convert() + def target(x): + def inner_fn(): + return x + resource + return inner_fn, 42 + self.assertEqual(target('foo')[0](), 'foosome-resource') + + f() + # TODO(brianklee): Reenable when we've revised module loading approach. + # self.assertNoMemoryLeaks(f) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 82471c4b64ad1e..197bd5a3e76992 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -108,7 +108,7 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): Raises: ValueError: if the entity type is not supported. """ - if program_ctx.options.verbose: + if program_ctx.options.verbose == converter.Verbosity.VERBOSE: logging.info('Converting {}'.format(o)) if tf_inspect.isclass(o): @@ -151,7 +151,7 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): program_ctx.add_to_cache(o, node) - if program_ctx.options.verbose: + if program_ctx.options.verbose == converter.Verbosity.VERBOSE: logging.info('Compiled output of {}:\n\n{}\n'.format( o, compiler.ast_to_source(node))) @@ -192,8 +192,7 @@ def class_to_graph(c, program_ctx): program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, - owner_type=c, - rewrite_errors=False) + owner_type=c) if class_namespace is None: class_namespace = namespace else: @@ -265,8 +264,7 @@ def _add_self_references(namespace, autograph_module): # Craft a module that exposes parts of the external API as well as certain # internal modules. ag_internal = imp.new_module('autograph') - ag_internal.converted_call = autograph_module.converted_call - ag_internal.ConversionOptions = converter.ConversionOptions + ag_internal.__dict__.update(autograph_module.__dict__) ag_internal.utils = utils ag_internal.function_scope = function_wrapping.function_scope ag_internal.rewrite_graph_construction_error = ( @@ -283,8 +281,7 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, - owner_type=None, - rewrite_errors=True): + owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) @@ -303,7 +300,7 @@ def function_to_graph(f, arg_types=arg_types, owner_type=owner_type) context = converter.EntityContext(namer, entity_info, program_ctx) - node = node_to_graph(node, context, rewrite_errors=rewrite_errors) + node = node_to_graph(node, context) # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -319,13 +316,12 @@ def function_to_graph(f, return [node], new_name, namespace -def node_to_graph(node, context, rewrite_errors=True): +def node_to_graph(node, context): """Convert Python code to equivalent TF graph mode code. Args: node: AST, the code to convert. context: converter.EntityContext - rewrite_errors: Boolean, whether or not to rewrite the error traceback. Returns: A tuple (node, deps): @@ -363,6 +359,5 @@ def node_to_graph(node, context, rewrite_errors=True): if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS): node = converter.apply_(node, context, side_effect_guards) node = converter.apply_(node, context, function_scopes) - if rewrite_errors: - node = converter.apply_(node, context, error_handlers) + node = converter.apply_(node, context, error_handlers) return node diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py index 21281aeb561475..06e66c5b5871d5 100644 --- a/tensorflow/python/autograph/pyct/compiler.py +++ b/tensorflow/python/autograph/pyct/compiler.py @@ -123,26 +123,15 @@ def ast_to_object(nodes, compiled_nodes = imp.load_source(module_name, f.name) # TODO(znado): Clean this up so we don't need to attach it to the namespace. - # TODO(znado): This does not work for classes because their methods share a - # namespace. - # This attaches the source map which is needed for error handling. Note that - # api.to_graph copies this source map into an attribute of the function. - # - # We need this so the ag_source_map__ variable is available to the call to - # rewrite_graph_construction_error in the except block inside each function - # that handles graph construction errors. - # # We cannot get the rewritten function name until it is too late so templating - # is hard, and this cleanly fixes the - # issues encountered with nested functions because this is attached to the - # outermost one. + # is hard, and this cleanly fixes the issues encountered with nested functions + # because this is attached to the outermost one. if include_source_map: # TODO(mdan): This name should be decided by the caller. source_map_name = 'ag_source_map__' - if source_map_name in compiled_nodes.__dict__: - raise ValueError('cannot convert %s because is has namespace attribute ' - '"%s", which is reserved for AutoGraph.' % - (compiled_nodes, source_map_name)) + assert source_map_name not in compiled_nodes.__dict__, ( + 'cannot convert %s because is has namespace attribute "%s", which is ' + 'reserved for AutoGraph.') % (compiled_nodes, source_map_name) compiled_nodes.__dict__[source_map_name] = source_map return compiled_nodes, source diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index e078cd56a21085..88f2d7a05649f5 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -185,12 +185,9 @@ def getmethodclass(m): return m.__class__ # Instance method and class methods: should be bound to a non-null "self". - # If self is a class, then it's a class method. if hasattr(m, '__self__'): if m.__self__: - if tf_inspect.isclass(m.__self__): - return m.__self__ - return type(m.__self__) + return m.__self__ # Class, static and unbound methods: search all defined classes in any # namespace. This is inefficient but more robust method. diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py index 7e79b3b9f68702..51116b6cac762f 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py @@ -184,16 +184,16 @@ def test_getmethodclass(self): test_obj = TestClass() self.assertEqual( inspect_utils.getmethodclass(test_obj.member_function), - TestClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.decorated_member), - TestClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.fn_decorated_member), - TestClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.wrap_decorated_member), - TestClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.static_method), TestClass) @@ -242,16 +242,16 @@ def wrap_decorated_member(self): test_obj = LocalClass() self.assertEqual( inspect_utils.getmethodclass(test_obj.member_function), - LocalClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.decorated_member), - LocalClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.fn_decorated_member), - LocalClass) + test_obj) self.assertEqual( inspect_utils.getmethodclass(test_obj.wrap_decorated_member), - LocalClass) + test_obj) def test_getmethodclass_callables(self): class TestCallable(object): diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py index 66e6be53c0c1d7..9901295445c7a7 100644 --- a/tensorflow/python/autograph/pyct/templates.py +++ b/tensorflow/python/autograph/pyct/templates.py @@ -144,6 +144,12 @@ def _check_inner_children_have_context(self, node): self._check_has_context(node) elif isinstance(node, (gast.Str, gast.Num)): pass + elif isinstance(node, gast.Call): + self._check_inner_children_have_context(node.func) + for a in node.args: + self._check_inner_children_have_context(a) + for k in node.keywords: + self._check_inner_children_have_context(k.value) else: raise ValueError('unexpected node type "%s"' % node) diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py index 30322418469e4f..54019ef5f4a20e 100644 --- a/tensorflow/python/autograph/pyct/templates_test.py +++ b/tensorflow/python/autograph/pyct/templates_test.py @@ -214,15 +214,15 @@ def test_fn(): result, _ = compiler.ast_to_object(node) self.assertEquals(3, result.test_fn()) - def replace_as_expression(self): + def test_replace_as_expression(self): template = """ foo(a) """ - node = templates.replace(template, foo='bar', a='baz') - self.assertTrue(node is gast.Call) + node = templates.replace_as_expression(template, foo='bar', a='baz') + self.assertIsInstance(node, gast.Call) self.assertEqual(node.func.id, 'bar') - self.assertEqual(node.func.args[0].id, 'baz') + self.assertEqual(node.args[0].id, 'baz') def test_replace_as_expression_restrictions(self): template = """ @@ -232,6 +232,13 @@ def test_replace_as_expression_restrictions(self): with self.assertRaises(ValueError): templates.replace_as_expression(template) + def test_function_call_in_list(self): + template = """ + foo(bar) + """ + source = parser.parse_expression('[a(b(1))]') + templates.replace_as_expression(template, bar=source) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py index 80a256bf7a8703..6efddba9792533 100644 --- a/tensorflow/python/client/client_lib.py +++ b/tensorflow/python/client/client_lib.py @@ -15,7 +15,7 @@ """Support for launching graphs and executing operations. -See the [Client](https://tensorflow.org/api_guides/python/client) guide. +See the [Client](https://www.tensorflow.org/guide/graphs) guide. """ from __future__ import absolute_import diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index aff4b1e32469b5..1074a8b5a92f2c 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 11, 5) @tf_export("compat.forward_compatible") diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index a1382f759828b6..bfe2e0cf7a1eb8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -279,6 +279,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python/data/experimental/ops:readers", "//tensorflow/python/data/ops:readers", diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py index 5ee94e14dcdd77..91ae8cb1bd2471 100644 --- a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py @@ -20,11 +20,13 @@ import numpy as np from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base +from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -234,6 +236,20 @@ def testIndefiniteRepeatShapeInference(self): if issubclass(clazz, ops.Tensor): self.assertEqual(32, shape[0]) + def testOldStyleReader(self): + with self.assertRaisesRegexp( + TypeError, r"The `reader` argument must return a `Dataset` object. " + r"`tf.ReaderBase` subclasses are not supported."): + _ = readers.make_batched_features_dataset( + file_pattern=self.test_filenames[0], batch_size=32, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string), + "label": parsing_ops.FixedLenFeature([], dtypes.string), + }, + reader=io_ops.TFRecordReader) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index f22084066d222c..5b75e54f66cd60 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -215,6 +215,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:clip_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_test.py index c2a5da3af00f9e..ed719a0ce9b2c2 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_test.py @@ -51,6 +51,9 @@ def testAssertNextInvalid(self): def testAssertNextShort(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + options = dataset_ops.Options() + options.experimental_autotune = False + dataset = dataset.with_options(options) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 59c24db5b39afc..f10b66ff69159e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import parsing_ops @@ -280,15 +281,22 @@ def dense_output_only_parse_fn(x): y for y in parse_result if not isinstance(y, sparse_tensor.SparseTensor) ] + def map_fn_with_cycle(x): + c = lambda i: math_ops.less(i, 10) + b = lambda i: math_ops.add(i, 1) + return control_flow_ops.while_loop(c, b, [x]) + # Misc test cases test_cases = [ ("Basic", lambda x: (x, x + 1), base_dataset_factory), ("Broadcast", lambda x: x + rand_val, base_dataset_factory), + ("Cycle", map_fn_with_cycle, lambda: dataset_ops.Dataset.from_tensors(1)), ("Const", lambda x: 2, base_dataset_factory), ("Cast", lambda x: math_ops.cast(x, dtypes.float64), base_dataset_factory), ("Reshape", lambda x: array_ops.reshape(x, (-1, 30)), base_dataset_factory), + ("Transpose", array_ops.transpose, base_dataset_factory), ("Unpack", array_ops.unstack, base_dataset_factory), ("UnpackNegativeAxis", lambda x: array_ops.unstack(x, axis=-1), base_dataset_factory), diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py index c04bef89f55b7f..bd263ee658f4f5 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py @@ -104,6 +104,23 @@ def testOptimizationThreadPoolDataset(self): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testOptimizationNonSerializable(self): + dataset = dataset_ops.Dataset.from_tensors(0) + dataset = dataset.apply(optimization.assert_next(["FiniteSkip"])) + dataset = dataset.skip(0) # Should not be removed by noop elimination + dataset = dataset.apply(optimization.non_serializable()) + dataset = dataset.apply(optimization.assert_next(["MemoryCacheImpl"])) + dataset = dataset.skip(0) # Should be removed by noop elimination + dataset = dataset.cache() + dataset = dataset_ops._OptimizeDataset(dataset, ["noop_elimination"]) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.cached_session() as sess: + self.assertEquals(0, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py index 723e709ae8dbb9..c74f754fefbc88 100644 --- a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py @@ -27,14 +27,13 @@ from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging # Helpers for creating Example objects example = example_pb2.Example @@ -49,70 +48,63 @@ sequence_example = example_pb2.SequenceExample -def _compare_output_to_expected(tester, dict_tensors, expected_tensors, - flat_output): - tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys())) - - i = 0 # Index into the flattened output of session.run() - for k, v in sorted(dict_tensors.items()): - # TODO(shivaniagrawal): flat_output is same as v. - expected_v = expected_tensors[k] - tf_logging.info("Comparing key: %s", k) - print("i", i, "flat_output", flat_output[i], "expected_v", expected_v) - if sparse_tensor.is_sparse(v): - # Three outputs for SparseTensor : indices, values, shape. - tester.assertEqual([k, len(expected_v)], [k, 3]) - print("i", i, "flat_output", flat_output[i].indices, "expected_v", - expected_v[0]) - tester.assertAllEqual(expected_v[0], flat_output[i].indices) - tester.assertAllEqual(expected_v[1], flat_output[i].values) - tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape) - else: - # One output for standard Tensor. - tester.assertAllEqual(expected_v, flat_output[i]) - i += 1 +@test_util.run_all_in_graph_and_eager_modes +class ParseExampleDatasetTest(test_base.DatasetTestBase): + def _compare_output_to_expected(self, dict_tensors, expected_tensors): + self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys())) -class ParseExampleDatasetTest(test_base.DatasetTestBase): + for k, v in sorted(dict_tensors.items()): + expected_v = expected_tensors[k] + if sparse_tensor.is_sparse(v): + self.assertSparseValuesEqual(expected_v, v) + else: + # One output for standard Tensor. + self.assertAllEqual(expected_v, v) def _test(self, input_tensor, feature_val, expected_values=None, - expected_err=None): - - with self.cached_session() as sess: - if expected_err: - with self.assertRaisesWithPredicateMatch(expected_err[0], - expected_err[1]): - dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( - contrib_parsing_ops.parse_example_dataset(feature_val)) - get_next = dataset.make_one_shot_iterator().get_next() - sess.run(get_next) - return - else: - # Returns dict w/ Tensors and SparseTensors. - # Check values. + expected_err=None, + create_iterator_twice=False): + + if expected_err: + with self.assertRaisesWithPredicateMatch(expected_err[0], + expected_err[1]): dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( contrib_parsing_ops.parse_example_dataset(feature_val)) - get_next = dataset.make_one_shot_iterator().get_next() - result = sess.run(get_next) - flattened = nest.flatten(result) - print("result", result, "expected_values", expected_values) - _compare_output_to_expected(self, result, expected_values, flattened) - - # Check shapes; if serialized is a Tensor we need its size to - # properly check. - batch_size = ( - input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else - np.asarray(input_tensor).size) - for k, f in feature_val.items(): - print("output_shapes as list ", - tuple(dataset.output_shapes[k].as_list())) - if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: - self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size) - elif isinstance(f, parsing_ops.VarLenFeature): - self.assertEqual(dataset.output_shapes[k].as_list()[1], None) + get_next = self.getNext(dataset) + self.evaluate(get_next()) + return + else: + # Returns dict w/ Tensors and SparseTensors. + # Check values. + dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( + contrib_parsing_ops.parse_example_dataset(feature_val)) + get_next = self.getNext(dataset) + result = self.evaluate(get_next()) + self._compare_output_to_expected(result, expected_values) + with self.assertRaises(errors_impl.OutOfRangeError): + self.evaluate(get_next()) + with self.assertRaises(errors_impl.OutOfRangeError): + self.evaluate(get_next()) + if create_iterator_twice: + get_next = self.getNext(dataset) + result = self.evaluate(get_next()) + self._compare_output_to_expected(result, expected_values) + with self.assertRaises(errors_impl.OutOfRangeError): + self.evaluate(get_next()) + # Check shapes; if serialized is a Tensor we need its size to + # properly check. + batch_size = ( + self.evaluate(input_tensor).size if isinstance(input_tensor, ops.Tensor) + else np.asarray(input_tensor).size) + for k, f in feature_val.items(): + if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: + self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size) + elif isinstance(f, parsing_ops.VarLenFeature): + self.assertEqual(dataset.output_shapes[k].as_list()[1], None) def testEmptySerializedWithAllDefaults(self): sparse_name = "st_a" @@ -123,13 +115,10 @@ def testEmptySerializedWithAllDefaults(self): b_default = np.random.rand(3, 3).astype(bytes) c_default = np.random.rand(2).astype(np.float32) - expected_st_a = ( # indices, values, shape - np.empty( - (0, 2), dtype=np.int64), # indices - np.empty( - (0,), dtype=np.int64), # sp_a is DT_INT64 - np.array( - [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape + np.empty((0, 2), dtype=np.int64), # indices + np.empty((0,), dtype=np.int64), # sp_a is DT_INT64 + np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 expected_output = { sparse_name: expected_st_a, @@ -152,7 +141,8 @@ def testEmptySerializedWithAllDefaults(self): parsing_ops.FixedLenFeature( (2,), dtypes.float32, default_value=c_default), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testEmptySerializedWithoutDefaultsShouldFail(self): input_features = { @@ -233,17 +223,14 @@ def testSerializedContainingSparse(self): serialized = [m.SerializeToString() for m in original] - expected_st_c = ( # indices, values, shape - np.array( - [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array( - [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array( - [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3 + expected_st_c = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), + np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), + np.array([4, 3], dtype=np.int64)) # batch == 2, max_elems = 3 - expected_st_d = ( # indices, values, shape - np.array( - [[3, 0]], dtype=np.int64), np.array( - ["hi"], dtype=bytes), np.array( - [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1 + expected_st_d = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[3, 0]], dtype=np.int64), np.array(["hi"], dtype=bytes), + np.array([4, 1], dtype=np.int64)) # batch == 2, max_elems = 1 expected_output = { "st_c": expected_st_c, @@ -255,7 +242,8 @@ def testSerializedContainingSparse(self): "st_c": parsing_ops.VarLenFeature(dtypes.float32), "st_d": parsing_ops.VarLenFeature(dtypes.string) }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContainingSparseFeature(self): original = [ @@ -280,19 +268,18 @@ def testSerializedContainingSparseFeature(self): serialized = [m.SerializeToString() for m in original] - expected_sp = ( # indices, values, shape - np.array( - [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64), - np.array( - [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array( - [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + expected_sp = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64), + np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), + np.array([4, 13], dtype=np.int64)) # batch == 4, max_elems = 13 expected_output = {"sp": expected_sp,} self._test( ops.convert_to_tensor(serialized), {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])}, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContainingSparseFeatureReuse(self): original = [ @@ -309,17 +296,15 @@ def testSerializedContainingSparseFeatureReuse(self): serialized = [m.SerializeToString() for m in original] - expected_sp1 = ( # indices, values, shape - np.array( - [[0, 5], [0, 10]], dtype=np.int64), np.array( - [3.0, 4.0], dtype=np.float32), np.array( - [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13 + expected_sp1 = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 5], [0, 10]], dtype=np.int64), + np.array([3.0, 4.0], dtype=np.float32), + np.array([2, 13], dtype=np.int64)) # batch == 2, max_elems = 13 - expected_sp2 = ( # indices, values, shape - np.array( - [[0, 5], [0, 10]], dtype=np.int64), np.array( - [5.0, 6.0], dtype=np.float32), np.array( - [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13 + expected_sp2 = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 5], [0, 10]], dtype=np.int64), + np.array([5.0, 6.0], dtype=np.float32), + np.array([2, 7], dtype=np.int64)) # batch == 2, max_elems = 13 expected_output = { "sp1": expected_sp1, @@ -334,7 +319,8 @@ def testSerializedContainingSparseFeatureReuse(self): parsing_ops.SparseFeature( "idx", "val2", dtypes.float32, size=7, already_sorted=True) }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContaining3DSparseFeature(self): original = [ @@ -361,11 +347,10 @@ def testSerializedContaining3DSparseFeature(self): serialized = [m.SerializeToString() for m in original] - expected_sp = ( + expected_sp = sparse_tensor.SparseTensorValue( # indices - np.array( - [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]], - dtype=np.int64), + np.array([[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]], + dtype=np.int64), # values np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), # shape batch == 4, max_elems = 13 @@ -379,7 +364,8 @@ def testSerializedContaining3DSparseFeature(self): parsing_ops.SparseFeature(["idx0", "idx1"], "val", dtypes.float32, [13, 3]) }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContainingDense(self): aname = "a" @@ -413,7 +399,8 @@ def testSerializedContainingDense(self): bname: parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) # This test is identical as the previous one except # for the creation of 'serialized'. @@ -459,7 +446,8 @@ def testSerializedContainingDenseWithConcat(self): bname: parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContainingDenseScalar(self): original = [ @@ -482,7 +470,8 @@ def testSerializedContainingDenseScalar(self): parsing_ops.FixedLenFeature( (1,), dtype=dtypes.float32, default_value=-1), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContainingDenseWithDefaults(self): original = [ @@ -519,21 +508,18 @@ def testSerializedContainingDenseWithDefaults(self): parsing_ops.FixedLenFeature( (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"), }, - expected_values=expected_output) - - def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self): - expected_st_a = ( # indices, values, shape - np.empty( - (0, 2), dtype=np.int64), # indices - np.empty( - (0,), dtype=np.int64), # sp_a is DT_INT64 - np.array( - [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 - expected_sp = ( # indices, values, shape - np.array( - [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array( - ["a", "b", "c"], dtype="|S"), np.array( - [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + expected_values=expected_output, + create_iterator_twice=True) + + def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self): + expected_st_a = sparse_tensor.SparseTensorValue( # indices, values, shape + np.empty((0, 2), dtype=np.int64), # indices + np.empty((0,), dtype=np.int64), # sp_a is DT_INT64 + np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + expected_sp = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64), + np.array(["a", "b", "c"], dtype="|S"), + np.array([2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 original = [ example(features=features({ @@ -577,20 +563,19 @@ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self): "c": parsing_ops.FixedLenFeature((2,), dtypes.float32), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) - def testSerializedContainingSparseAndSparseFeatureWithReuse(self): - expected_idx = ( # indices, values, shape - np.array( - [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), - np.array([0, 3, 7, 1]), np.array( - [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2 + def testerializedContainingSparseAndSparseFeatureWithReuse(self): + expected_idx = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), + np.array([0, 3, 7, 1]), + np.array([2, 2], dtype=np.int64)) # batch == 4, max_elems = 2 - expected_sp = ( # indices, values, shape - np.array( - [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array( - ["a", "b", "d", "c"], dtype="|S"), np.array( - [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + expected_sp = sparse_tensor.SparseTensorValue( # indices, values, shape + np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), + np.array(["a", "b", "d", "c"], dtype="|S"), + np.array([2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 original = [ example(features=features({ @@ -616,7 +601,8 @@ def testSerializedContainingSparseAndSparseFeatureWithReuse(self): "sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size): # During parsing, data read from the serialized proto is stored in buffers. @@ -675,18 +661,18 @@ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size): allow_missing=True, default_value="default"), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) def testSerializedContainingVarLenDenseLargerBatch(self): np.random.seed(3456) for batch_size in (1, 10, 20, 100, 256): self._testSerializedContainingVarLenDenseLargerBatch(batch_size) - def testSerializedContainingVarLenDense(self): + def testSkipEagerSerializedShapeMismatch(self): aname = "a" bname = "b" cname = "c" - dname = "d" original = [ example(features=features({ cname: int64_feature([2]), @@ -705,6 +691,47 @@ def testSerializedContainingVarLenDense(self): })), ] + serialized = [m.SerializeToString() for m in original] + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature((2, 1), + dtype=dtypes.float32, + allow_missing=True, + default_value=[]), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "Cannot reshape a tensor with 0 elements to shape")) + + def testSerializedContainingVarLenDense(self): + aname = "a" + bname = "b" + cname = "c" + dname = "d" + original = [ + example(features=features({ + cname: int64_feature([2]), + })), + example( + features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str", b"b1_str"]), + })), + example( + features=features({ + aname: float_feature([-1, -1, 2, 2]), + bname: bytes_feature([b"b1"]), + })), + example( + features=features({ + aname: float_feature([]), + cname: int64_feature([3]), + })), + ] + serialized = [m.SerializeToString() for m in original] expected_output = { @@ -742,7 +769,8 @@ def testSerializedContainingVarLenDense(self): parsing_ops.FixedLenSequenceFeature( shape=[], dtype=dtypes.string, allow_missing=True), }, - expected_values=expected_output) + expected_values=expected_output, + create_iterator_twice=True) # Test with padding values. expected_output_custom_padding = dict(expected_output) @@ -789,21 +817,6 @@ def testSerializedContainingVarLenDense(self): errors_impl.OpError, "Key: b, Index: 2. " "Number of bytes values is not a multiple of stride length.")) - self._test( - ops.convert_to_tensor(serialized), { - aname: - parsing_ops.FixedLenSequenceFeature( - (2, 1), - dtype=dtypes.float32, - allow_missing=True, - default_value=[]), - bname: - parsing_ops.FixedLenSequenceFeature( - (2, 1, 1), dtype=dtypes.string, allow_missing=True), - }, - expected_err=(ValueError, - "Cannot reshape a tensor with 0 elements to shape")) - self._test( ops.convert_to_tensor(serialized), { aname: diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py index 427654cd7628b9..4d794b4b8458d8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py @@ -21,6 +21,8 @@ from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.experimental.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -151,6 +153,64 @@ def testFilteredElementsStats(self): self._assertSummaryHasScalarValue( sess.run(summary_t), "Filter::filtered_elements", 34.0) + def testMapBufferUtilization(self): + + def dataset_fn(): + return dataset_ops.Dataset.range(10).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x])), + num_parallel_calls=4) + + self._testParallelCallsStats( + dataset_fn, "ParallelMap", 10, function_processing_time=True) + + def testMapAutoTuneBufferUtilization(self): + + def dataset_fn(): + dataset = dataset_ops.Dataset.range(10).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x])), + num_parallel_calls=optimization.AUTOTUNE) + options = dataset_ops.Options() + options.experimental_autotune = True + return dataset.with_options(options) + + self._testParallelCallsStats( + dataset_fn, "ParallelMap", 10, function_processing_time=True) + + def testInterleaveAutoTuneBufferUtilization(self): + + def dataset_fn(): + dataset = dataset_ops.Dataset.range(10).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))) + dataset = dataset_ops.Dataset.range(1).interleave( + lambda _: dataset, + cycle_length=1, + num_parallel_calls=optimization.AUTOTUNE) + options = dataset_ops.Options() + options.experimental_autotune = True + return dataset.with_options(options) + + self._testParallelCallsStats(dataset_fn, "ParallelInterleaveV2", 10) + + def testMapAndBatchAutoTuneBufferUtilization(self): + + def dataset_fn(): + dataset = dataset_ops.Dataset.range(100).apply( + batching.map_and_batch( + lambda x: array_ops.tile([x], ops.convert_to_tensor([2])), + num_parallel_calls=optimization.AUTOTUNE, + batch_size=16)) + options = dataset_ops.Options() + options.experimental_autotune = True + return dataset.with_options(options) + + num_output = 100 // 16 + 1 + self._testParallelCallsStats( + dataset_fn, + "MapAndBatch", + num_output, + check_elements=False, + function_processing_time=True) + def testReinitialize(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( @@ -287,22 +347,32 @@ def testFeaturesStats(self): total_records = num_epochs * self._num_records batch_size = 2 stats_aggregator = stats_ops.StatsAggregator() - dataset = self.make_batch_feature( - filenames=self.test_filenames[0], - num_epochs=num_epochs, - batch_size=batch_size, - shuffle=True, - shuffle_seed=5, - drop_final_batch=False).apply( - stats_ops.set_stats_aggregator(stats_aggregator, "record_stats")) - iterator = dataset.make_initializable_iterator() + + def dataset_fn(): + return self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=False) + + num_output = total_records // batch_size + if total_records % batch_size: + num_output = total_records // batch_size + 1 + + self._testParallelCallsStats( + dataset_fn, "ParseExample", num_output, check_elements=False) + + iterator = dataset_fn().apply( + stats_ops.set_stats_aggregator( + stats_aggregator, "record_stats")).make_initializable_iterator() next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: sess.run(iterator.initializer) - for _ in range(total_records // batch_size + 1 if total_records % - batch_size else total_records // batch_size): + for _ in range(num_output): sess.run(next_element) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py index 80f26259272061..a4e6242b00c849 100644 --- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py @@ -17,9 +17,12 @@ from __future__ import division from __future__ import print_function +import numpy as np from tensorflow.core.framework import summary_pb2 +from tensorflow.python.data.experimental.ops import stats_ops from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.framework import errors class StatsDatasetTestBase(test_base.DatasetTestBase): @@ -42,6 +45,16 @@ def _assertSummaryHasCount(self, summary_str, tag, expected_value): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasCountMoreOrEqualGeneralisedTag(self, summary_str, tag, + expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag in value.tag: + self.assertGreaterEqual(value.histo.num, expected_value) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): summary_proto = summary_pb2.Summary() summary_proto.ParseFromString(summary_str) @@ -69,3 +82,37 @@ def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value): self.assertEqual(expected_value, value.simple_value) return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _testParallelCallsStats(self, + dataset_fn, + dataset_name, + num_output, + function_processing_time=False, + check_elements=True): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_fn().apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(num_output): + next_ = sess.run(next_element) + if check_elements: + self.assertAllEqual(np.array([i] * i, dtype=np.int64), next_) + summary_str = sess.run(summary_t) + if function_processing_time: + self._assertSummaryHasCountMoreOrEqualGeneralisedTag( + summary_str, "::execution_time", float(i + 1)) + self._assertSummaryContains(summary_str, + dataset_name + "::num_parallel_calls") + self._assertSummaryContains(summary_str, + dataset_name + "::active_parallel_calls") + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + if function_processing_time: + summary_str = sess.run(summary_t) + self._assertSummaryHasCountMoreOrEqualGeneralisedTag( + summary_str, "::execution_time", float(num_output)) diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index 323298e33a6b1b..eda547c37af88a 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -82,6 +82,7 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py index ec1a3adf0c17e8..5d729d392ac5ec 100644 --- a/tensorflow/python/data/experimental/ops/map_defun.py +++ b/tensorflow/python/data/experimental/ops/map_defun.py @@ -52,7 +52,7 @@ def map_defun(fn, elems, output_dtypes, output_shapes): raise ValueError("`output_shapes` must be a list of `tf.TensorShape` " "objects.") - concrete_fn = fn.get_concrete_function() + concrete_fn = fn._get_concrete_function_internal() # pylint: disable=protected-access # TODO(shivaniagrawal/rachelim): what about functions created without # input_signature. elems = [ops.convert_to_tensor(e) for e in elems] diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py index 8e1de136b66448..b744db7f1e5fbd 100644 --- a/tensorflow/python/data/experimental/ops/optimization.py +++ b/tensorflow/python/data/experimental/ops/optimization.py @@ -65,6 +65,21 @@ def _apply_fn(dataset): return _apply_fn +def non_serializable(): + """A non-serializable identity transformation. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _NonSerializableDataset(dataset) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -115,3 +130,28 @@ def output_shapes(self): def output_types(self): return self._input_dataset.output_types + +class _NonSerializableDataset(dataset_ops.UnaryDataset): + """A `Dataset` that performs non-serializable identity transformation.""" + + def __init__(self, input_dataset): + """See `non_serializable()` for details.""" + super(_NonSerializableDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_experimental_dataset_ops.experimental_non_serializable_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py index 8cc7452c419362..a55b8bfb769e7a 100644 --- a/tensorflow/python/data/experimental/ops/prefetching_ops.py +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py @@ -138,12 +138,14 @@ def _prefetch_fn(handle): ret = remote_iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) + self._prefetch_fn = _prefetch_fn._get_concrete_function_internal() # pylint: disable=protected-access + iterator_device = ged_ops.experimental_iterator_get_device( self._input_iterator._iterator_resource) with ops.device(device): self._buffering_resource = function_buffering_resource( - f=_prefetch_fn.get_concrete_function(), + f=self._prefetch_fn, target_device=iterator_device, string_arg=input_iterator_handle, buffer_size=buffer_size, @@ -235,7 +237,7 @@ def _prefetch_fn(handle): ret = remote_iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) - self._prefetch_fn = _prefetch_fn.get_concrete_function() + self._prefetch_fn = _prefetch_fn._get_concrete_function_internal() # pylint: disable=protected-access with ops.device(device): self._buffering_resource = function_buffering_resource( @@ -420,15 +422,17 @@ def _init_func(): [gen_dataset_ops.make_iterator(ds_variant, resource)]): return gen_dataset_ops.iterator_to_string_handle(resource) + init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access + @function.defun() def _remote_init_func(): return functional_ops.remote_call( target=self._source_device, - args=_init_func.get_concrete_function().captured_inputs, + args=init_func_concrete.captured_inputs, Tout=[dtypes.string], - f=_init_func.get_concrete_function()) + f=init_func_concrete) - self._init_func = _remote_init_func.get_concrete_function() + self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access self._init_captured_args = self._init_func.captured_inputs @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) @@ -447,16 +451,18 @@ def _next_func(string_handle): ret = iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) + next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _remote_next_func(string_handle): return functional_ops.remote_call( target=self._source_device, args=[string_handle] + - _next_func.get_concrete_function().captured_inputs, + next_func_concrete.captured_inputs, Tout=self._flat_output_types, - f=_next_func.get_concrete_function()) + f=next_func_concrete) - self._next_func = _remote_next_func.get_concrete_function() + self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access self._next_captured_args = self._next_func.captured_inputs @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) @@ -477,16 +483,19 @@ def _finalize_func(string_handle): iterator_resource, ignore_lookup_error=True)]): return array_ops.constant(0, dtypes.int64) + finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _remote_finalize_func(string_handle): return functional_ops.remote_call( target=self._source_device, args=[string_handle] + - _finalize_func.get_concrete_function().captured_inputs, + finalize_func_concrete.captured_inputs, Tout=[dtypes.int64], - f=_finalize_func.get_concrete_function()) + f=finalize_func_concrete) - self._finalize_func = _remote_finalize_func.get_concrete_function() + self._finalize_func = _remote_finalize_func._get_concrete_function_internal( # pylint: disable=protected-access + ) self._finalize_captured_args = self._finalize_func.captured_inputs g = ops.get_default_graph() diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py index 3b2d0945148e44..fe601925860b4e 100644 --- a/tensorflow/python/data/experimental/ops/readers.py +++ b/tensorflow/python/data/experimental/ops/readers.py @@ -38,6 +38,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_experimental_dataset_ops +from tensorflow.python.ops import io_ops from tensorflow.python.platform import gfile from tensorflow.python.util.tf_export import tf_export @@ -760,6 +761,7 @@ def make_batched_features_dataset(file_pattern, Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects. Raises: + TypeError: If `reader` is a `tf.ReaderBase` subclass. ValueError: If `label_key` is not one of the `features` keys. """ # Create dataset of all matching filenames @@ -768,6 +770,12 @@ def make_batched_features_dataset(file_pattern, if shuffle: dataset = dataset.shuffle(len(filenames), shuffle_seed) + if isinstance(reader, type) and issubclass(reader, io_ops.ReaderBase): + raise TypeError("The `reader` argument must return a `Dataset` object. " + "`tf.ReaderBase` subclasses are not supported. For " + "example, pass `tf.data.TFRecordDataset` instead of " + "`tf.TFRecordReader`.") + # Read `Example` records from files as tensor objects. if reader_args is None: reader_args = [] diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py index 63d2be4371c3d8..a5324af4d0cf95 100644 --- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -226,7 +226,8 @@ def testOptionsTwiceDifferent(self): ds = dataset_ops.Dataset.range(0).with_options(options1).with_options( options2) self.assertTrue(ds.options().experimental_autotune) - self.assertFalse(ds.options().experimental_filter_fusion) + # Explicitly check that flag is False since assertFalse allows None + self.assertIs(ds.options().experimental_filter_fusion, False) def testOptionsTwiceDifferentError(self): options1 = dataset_ops.Options() @@ -237,6 +238,17 @@ def testOptionsTwiceDifferentError(self): "Cannot merge incompatible values of option"): dataset_ops.Dataset.range(0).with_options(options1).with_options(options2) + def testOptionsMergeOptionsFromMultipleInputs(self): + options1 = dataset_ops.Options() + options1.experimental_autotune = True + options2 = dataset_ops.Options() + options2.experimental_filter_fusion = True + ds = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(0).with_options(options1), + dataset_ops.Dataset.range(0).with_options(options2))) + self.assertTrue(ds.options().experimental_autotune) + self.assertTrue(ds.options().experimental_filter_fusion) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py index 25c91b42dc65f8..bf5fd781d65cd1 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py @@ -134,7 +134,7 @@ def testCaptureHashTableInSharedIterator(self): get_next = iterator.get_next() with session.Session(worker[0].target) as sess: - sess.run(table.init) + sess.run(table.initializer) sess.run(init_op) self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next)) diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py index 8eb13815d4a5da..b58c1444daeb03 100644 --- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py @@ -238,6 +238,54 @@ def testNoShuffle(self): self.assertEqual(produced_filenames[:len(filenames)], produced_filenames[len(filenames):]) + def testMultiplePatternsAsList(self): + filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] + self._touchTempFiles(filenames) + + patterns = [path.join(self.tmp_dir, pat) for pat in ['*.py', '*.txt']] + dataset = dataset_ops.Dataset.list_files(patterns) + with self.cached_session() as sess: + itr = dataset.make_one_shot_iterator() + next_element = itr.get_next() + + full_filenames = [] + produced_filenames = [] + for filename in filenames[:-1]: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) + self.assertItemsEqual(full_filenames, produced_filenames) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + def testMultiplePatternsAsTensor(self): + filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] + self._touchTempFiles(filenames) + + filename_placeholder = array_ops.placeholder( + dtypes.string, shape=[ + 2, + ]) + dataset = dataset_ops.Dataset.list_files(filename_placeholder) + + with self.cached_session() as sess: + itr = dataset.make_initializable_iterator() + next_element = itr.get_next() + patterns = [path.join(self.tmp_dir, pat) for pat in ['*.py', '*.txt']] + sess.run(itr.initializer, feed_dict={filename_placeholder: patterns}) + + full_filenames = [] + produced_filenames = [] + for filename in filenames[:-1]: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(next_element))) + self.assertItemsEqual(full_filenames, produced_filenames) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 6a9c8843186061..81ef7d16be2c9d 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -353,7 +353,7 @@ def testCaptureHashTable(self): get_next = iterator.get_next() with self.cached_session() as sess: - sess.run(table.init) + sess.run(table.initializer) sess.run(init_op) sess.run(get_next) sess.run(get_next) diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py index b7e2a5f615ea97..b71e6b2ea43a19 100644 --- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops @@ -35,139 +35,58 @@ from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class RangeDatasetTest(test_base.DatasetTestBase): - def tearDown(self): - # Remove all checkpoint files. - prefix = self._iterator_checkpoint_prefix() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - def testStop(self): - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={stop: 5}) - for i in range(5): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + dataset = dataset_ops.Dataset.range(5) + self.assertDatasetProduces(dataset, expected_output=range(5)) def testStartStop(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 5}) - for i in range(2, 5): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + start, stop = 2, 5 + dataset = dataset_ops.Dataset.range(start, stop) + self.assertDatasetProduces(dataset, expected_output=range(2, 5)) def testStartStopStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2}) - for i in range(2, 10, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + start, stop, step = 2, 10, 2 + dataset = dataset_ops.Dataset.range(start, stop, step) + self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2)) def testZeroStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - - with self.cached_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0}) + start, stop, step = 2, 10, 0 + dataset = dataset_ops.Dataset.range(start, stop, step) + self.assertDatasetProduces( + dataset, expected_err=(errors.InvalidArgumentError, "")) def testNegativeStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(2, 10, -1): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + start, stop, step = 2, 10, -1 + dataset = dataset_ops.Dataset.range(start, stop, step) + self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1)) def testStopLessThanStart(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(10, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + start, stop = 10, 2 + dataset = dataset_ops.Dataset.range(start, stop) + self.assertDatasetProduces(dataset, expected_output=range(10, 2)) def testStopLessThanStartWithPositiveStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(10, 2, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + start, stop, step = 10, 2, 2 + dataset = dataset_ops.Dataset.range(start, stop, step) + self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2)) def testStopLessThanStartWithNegativeStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1}) - for i in range(10, 2, -1): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + start, stop, step = 10, 2, -1 + dataset = dataset_ops.Dataset.range(start, stop, step) + self.assertDatasetProduces(dataset, expected_output=range(10, 2, -1)) + + +class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase): + + def tearDown(self): + # Remove all checkpoint files. + prefix = self._iterator_checkpoint_prefix() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index b73a94e6833636..edb3eff3c172d9 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -62,10 +62,42 @@ def getNext(self, dataset): nxt = it.get_next() return lambda: nxt + def _compare_output_to_expected(self, result_values, expected_values): + for i in range(len(result_values)): + if sparse_tensor.is_sparse(result_values[i]): + self.assertSparseValuesEqual(result_values[i], expected_values[i]) + else: + self.assertAllEqual(result_values[i], expected_values[i]) + + def assertDatasetProduces(self, + input_dataset, + expected_output=None, + expected_err=None, + create_iterator_twice=True): + + if expected_err: + with self.assertRaisesWithPredicateMatch(expected_err[0], + expected_err[1]): + get_next = self.getNext(input_dataset) + self.evaluate(get_next()) + return + repeated = 2 if create_iterator_twice else 1 + for _ in range(repeated): + get_next = self.getNext(input_dataset) + result = [] + for _ in range(len(expected_output)): + result.append(self.evaluate(get_next())) + self._compare_output_to_expected(result, expected_output) + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(get_next()) + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(get_next()) + def assertDatasetsEqual(self, dataset1, dataset2): """Checks that datasets are equal. Supports both graph and eager mode.""" self.assertEqual(dataset1.output_types, dataset2.output_types) self.assertEqual(dataset1.output_classes, dataset2.output_classes) + flattened_types = nest.flatten(dataset1.output_types) next1 = self.getNext(dataset1) next2 = self.getNext(dataset2) @@ -82,12 +114,12 @@ def assertDatasetsEqual(self, dataset1, dataset2): op2 = nest.flatten(op2) assert len(op1) == len(op2) for i in range(len(op1)): - if isinstance( - op1[i], - (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + if sparse_tensor.is_sparse(op1[i]): self.assertSparseValuesEqual(op1[i], op2[i]) - else: + elif flattened_types[i] == dtypes.string: self.assertAllEqual(op1[i], op2[i]) + else: + self.assertAllClose(op1[i], op2[i]) def assertDatasetsRaiseSameError(self, dataset1, diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 38e46963790f25..e4b5da6403265b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -30,7 +30,6 @@ from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.eager import context -from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -39,7 +38,6 @@ from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -54,6 +52,7 @@ @tf_export("data.Dataset") +@six.add_metaclass(abc.ABCMeta) class Dataset(object): """Represents a potentially large set of elements. @@ -61,8 +60,6 @@ class Dataset(object): collection of elements (nested structures of tensors) and a "logical plan" of transformations that act on those elements. """ - __metaclass__ = abc.ABCMeta - def __init__(self): pass @@ -91,16 +88,17 @@ def _inputs(self): raise NotImplementedError("Dataset._inputs") def options(self): - """Returns the options for this dataset. + """Returns the options for this dataset and its inputs. Returns: A `tf.data.Options` object representing the dataset options. """ + options = Options() for input_dataset in self._inputs(): - options = input_dataset.options() - if options is not None: - return options - return Options() + input_options = input_dataset.options() + if input_options is not None: + options = options.merge(input_options) + return options def _apply_options(self): dataset = self @@ -108,7 +106,7 @@ def _apply_options(self): static_optimizations = options._static_optimizations() # pylint: disable=protected-access if static_optimizations: dataset = _OptimizeDataset(dataset, static_optimizations) - if options.experimental_autotune: + if options.experimental_autotune is not False: dataset = _ModelDataset(dataset) return dataset @@ -186,7 +184,8 @@ def make_one_shot_iterator(self): An `Iterator` over the elements of this dataset. """ if context.executing_eagerly(): - return iterator_ops.EagerIterator(self) + dataset = self._apply_options() + return iterator_ops.EagerIterator(dataset) graph_level_seed, op_level_seed = core_random_seed.get_seed(None) @@ -665,7 +664,7 @@ def prefetch(self, buffer_size): @staticmethod def list_files(file_pattern, shuffle=None, seed=None): - """A dataset of all files matching a pattern. + """A dataset of all files matching one or more glob patterns. NOTE: The default behavior of this method is to return filenames in a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False` @@ -682,12 +681,13 @@ def list_files(file_pattern, shuffle=None, seed=None): - /path/to/dir/c.py Args: - file_pattern: A string or scalar string `tf.Tensor`, representing - the filename pattern that will be matched. + file_pattern: A string, a list of strings, or a `tf.Tensor` of string type + (scalar or vector), representing the filename glob (i.e. shell wildcard) + pattern(s) that will be matched. shuffle: (Optional.) If `True`, the file names will be shuffled randomly. Defaults to `True`. - seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - random seed that will be used to create the distribution. See + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random + seed that will be used to create the distribution. See `tf.set_random_seed` for behavior. Returns: @@ -1749,8 +1749,7 @@ def output_types(self): class StructuredFunctionWrapper(object): - """A wrapper for `defun` that supports structured arguments and return values. - + """A wrapper for `Defun` that supports structured arguments and return values. """ def __init__(self, @@ -1808,15 +1807,15 @@ def __init__(self, ".", "_")[:-2] if len(transformation_name) > 2 else "" self._func_name = "_".join([ readable_transformation_name, - function_utils.get_func_name(func) + function_utils.get_func_name(func), + str(ops.uid()) + ]) # TODO(b/110122868): Enable this support for all `tf.data` functions. self._nested_dataset_support = experimental_nested_dataset_support - @eager_function.defun_with_attributes( - input_signature=self._defun_args(), - attributes={"func_name": self._func_name}) + @function.Defun(*self._defun_args(), func_name=self._func_name) def tf_data_structured_function_wrapper(*args): """Wrapper for passing nested structures to and from tf.data functions.""" flat_args = [] @@ -1901,50 +1900,37 @@ def tf_data_structured_function_wrapper(*args): self._output_classes = nest.pack_sequence_as(ret, flat_classes) self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) self._output_types = nest.pack_sequence_as(ret, flat_types) - return flat_ret - table_initializers_len = len(ops.get_default_graph().get_collection( - ops.GraphKeys.TABLE_INITIALIZERS)) + _warn_if_collections(transformation_name) - self._function = tf_data_structured_function_wrapper.get_concrete_function() + return flat_ret + self._function = tf_data_structured_function_wrapper if add_to_graph: self._function.add_to_graph(ops.get_default_graph()) - - if len( - self._function.graph.get_collection( - ops.GraphKeys.TABLE_INITIALIZERS)) != table_initializers_len: - warnings.warn( - "Creating lookup tables inside a function passed to %s is not" - " supported. Create each table outside the function, and " - "capture it inside the function to use it." % transformation_name) + else: + # Use the private method that will execute + # `tf_data_structured_function_wrapper` but delay adding it to the graph + # in case (e.g.) we need to rerun the function. + self._function._create_definition_if_needed() # pylint: disable=protected-access def _defun_args(self): - """Returns a list of `tf.TensorSpec` for the input element structure.""" + """Returns a flat list of `tf.DType` for the input element structure.""" ret = [] - for input_type, input_shape, input_class in zip( - nest.flatten(self._input_types), nest.flatten(self._input_shapes), - nest.flatten(self._input_classes)): + for input_type, input_class in zip(nest.flatten(self._input_types), + nest.flatten(self._input_classes)): # TODO(b/110122868): Add a registration mechanism for new component types. if input_class is sparse_tensor_lib.SparseTensor: - # Give TensorSpec objects unique names to satisfy error checking in - # get_concrete_function. - ret.append( - tensor_spec.TensorSpec( - tensor_shape.TensorShape(None), dtypes.variant, - name="arg_{}".format(len(ret)))) + ret.append(dtypes.variant) elif isinstance(input_class, _NestedDatasetComponent): if not self._nested_dataset_support: raise NotImplementedError( "The %s transformation does not currently support nested " "datasets as inputs." % self._transformation_name) - ret.append( - tensor_spec.TensorSpec(tensor_shape.scalar(), dtypes.variant, - name="arg_{}".format(len(ret)))) + ret.append(dtypes.variant) else: assert isinstance(input_type, dtypes.DType) - ret.append(tensor_spec.TensorSpec(input_shape, input_type, - name="arg_{}".format(len(ret)))) + ret.append(input_type) return ret @property @@ -2665,6 +2651,24 @@ def _should_unpack_args(args): return type(args) is tuple # pylint: disable=unidiomatic-typecheck +def _warn_if_collections(transformation_name): + """Prints warning message if the current graph uses common graph collections. + + NOTE(mrry): Currently a warning is only generated for lookup tables. Any + variables created will be automatically hoisted out to the outermost scope + using `init_scope()`. Some collections (such as for control-flow contexts) + are benign and should not generate a warning. + + Args: + transformation_name: A human-readable name for the transformation. + """ + if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS): + warnings.warn("Creating lookup tables inside a function passed to %s is not" + " supported. Create each table outside the function, and " + "capture it inside the function to use it." + % transformation_name) + + class MapDataset(UnaryDataset): """A `Dataset` that maps a function over elements in its input.""" diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 9214f7d79d0e79..68b03ba93be675 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -571,7 +571,7 @@ def _next_internal(self): output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) - return self._structure._from_tensor_list(ret) # pylint: disable=protected-access + return self._structure._from_compatible_tensor_list(ret) # pylint: disable=protected-access def next(self): """Returns a nested structure of `tf.Tensor`s containing the next element. diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index 7b103d241b6db1..0f9add6461aeeb 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -55,15 +55,17 @@ def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, def _init_func(): return multi_device_iterator_string_handle + init_func_concrete = _init_func._get_concrete_function_internal() # pylint: disable=protected-access + @function.defun() def _remote_init_func(): return functional_ops.remote_call( target=source_device, - args=_init_func.get_concrete_function().captured_inputs, + args=init_func_concrete.captured_inputs, Tout=[dtypes.string], - f=_init_func.get_concrete_function()) + f=init_func_concrete) - self._init_func = _remote_init_func.get_concrete_function() + self._init_func = _remote_init_func._get_concrete_function_internal() # pylint: disable=protected-access self._init_captured_args = self._init_func.captured_inputs @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) @@ -80,6 +82,8 @@ def _next_func(string_handle): output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) + next_func_concrete = _next_func._get_concrete_function_internal() # pylint: disable=protected-access + @function.defun_with_attributes( input_signature=[tensor_spec.TensorSpec([], dtypes.string)], attributes={"experimental_ints_on_device": True}) @@ -87,27 +91,30 @@ def _remote_next_func(string_handle): return functional_ops.remote_call( target=source_device, args=[string_handle] + - _next_func.get_concrete_function().captured_inputs, + next_func_concrete.captured_inputs, Tout=self._flat_output_types, - f=_next_func.get_concrete_function()) + f=next_func_concrete) - self._next_func = _remote_next_func.get_concrete_function() + self._next_func = _remote_next_func._get_concrete_function_internal() # pylint: disable=protected-access self._next_captured_args = self._next_func.captured_inputs @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _finalize_func(unused_string_handle): return array_ops.constant(0, dtypes.int64) + finalize_func_concrete = _finalize_func._get_concrete_function_internal() # pylint: disable=protected-access + @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)]) def _remote_finalize_func(string_handle): return functional_ops.remote_call( target=source_device, args=[string_handle] + - _finalize_func.get_concrete_function().captured_inputs, + finalize_func_concrete.captured_inputs, Tout=[dtypes.int64], - f=_finalize_func.get_concrete_function()) + f=finalize_func_concrete) - self._finalize_func = _remote_finalize_func.get_concrete_function() + self._finalize_func = _remote_finalize_func._get_concrete_function_internal( # pylint: disable=protected-access + ) self._finalize_captured_args = self._finalize_func.captured_inputs def _as_variant_tensor(self): @@ -213,6 +220,10 @@ def __init__(self, self._dataset.output_types, self._dataset.output_classes) if prefetch_buffer_size > 0: ds = ds.prefetch(prefetch_buffer_size) + # TODO(jsimsa): Enable auto-tuning when supported for non-CPU devices. + options = dataset_ops.Options() + options.experimental_autotune = False + ds = ds.with_options(options) with ops.device(device): self._device_iterators.append(ds.make_initializable_iterator()) diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py index 620f403f841f43..91cf883ce94648 100644 --- a/tensorflow/python/data/ops/optional_ops.py +++ b/tensorflow/python/data/ops/optional_ops.py @@ -19,6 +19,8 @@ import abc +import six + from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -26,6 +28,7 @@ from tensorflow.python.ops import gen_dataset_ops +@six.add_metaclass(abc.ABCMeta) class Optional(object): """Wraps a nested structure of tensors that may/may not be present at runtime. @@ -169,6 +172,9 @@ def _from_tensor_list(self, flat_value): not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())): raise ValueError( "OptionalStructure corresponds to a single tf.variant scalar.") + return self._from_compatible_tensor_list(flat_value) + + def _from_compatible_tensor_list(self, flat_value): # pylint: disable=protected-access return _OptionalImpl(flat_value[0], self._value_structure) diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 10bb6f3aa4e75f..9a3118297dbf71 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -19,6 +19,8 @@ import abc +import six + from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -31,6 +33,7 @@ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {} +@six.add_metaclass(abc.ABCMeta) class Structure(object): """Represents structural information, such as type and shape, about a value. @@ -46,7 +49,6 @@ class Structure(object): and `tf.data.Dataset.output_classes`, and similar properties and arguments in the `tf.data.Iterator` and `Optional` classes. """ - __metaclass__ = abc.ABCMeta @abc.abstractproperty def _flat_shapes(self): @@ -113,17 +115,35 @@ def _to_tensor_list(self, value): def _from_tensor_list(self, flat_value): """Builds a flat list of `tf.Tensor` into a value matching this structure. - Requires: The shapes and types of the tensors in `flat_value` must be - compatible with `self._flat_shapes` and `self._flat_types` respectively. - Args: flat_value: A list of `tf.Tensor` with compatible flat structure. Returns: A structured object matching this structure. + + Raises: + ValueError: If the shapes and types of the tensors in `flat_value` are not + compatible with `self._flat_shapes` and `self._flat_types` respectively. """ raise NotImplementedError("Structure._from_tensor_list()") + def _from_compatible_tensor_list(self, flat_value): + """A version of `_from_tensor_list()` that may avoid performing checks. + + NOTE: This method should be used to avoid checks for performance reasons, + when the validity of `flat_value` has been validated by other means. + The shapes and types of the tensors in `flat_value` must be compatible with + `self._flat_shapes` and `self._flat_types` respectively. The behavior is + undefined if this requirement is not met. + + Args: + flat_value: A list of `tf.Tensor` with compatible flat structure. + + Returns: + A structured object matching this structure. + """ + return self._from_tensor_list(flat_value) + @staticmethod def from_value(value): """Returns a `Structure` that represents the given `value`. @@ -238,6 +258,7 @@ class NestedStructure(Structure): def __init__(self, nested_structure): self._nested_structure = nested_structure + self._flat_nested_structure = nest.flatten(nested_structure) self._flat_shapes_list = [] self._flat_types_list = [] for s in nest.flatten(nested_structure): @@ -280,8 +301,7 @@ def _to_tensor_list(self, value): raise ValueError("The value %r is not compatible with the nested " "structure %r." % (value, self._nested_structure)) - for sub_value, structure in zip(flat_value, - nest.flatten(self._nested_structure)): + for sub_value, structure in zip(flat_value, self._flat_nested_structure): if not structure.is_compatible_with(Structure.from_value(sub_value)): raise ValueError("Component value %r is not compatible with the nested " "structure %r." % (sub_value, structure)) @@ -294,12 +314,18 @@ def _from_tensor_list(self, flat_value): % (len(self._flat_types), len(flat_value))) flat_ret = [] - for sub_value, structure in zip(flat_value, - nest.flatten(self._nested_structure)): + for sub_value, structure in zip(flat_value, self._flat_nested_structure): flat_ret.append(structure._from_tensor_list([sub_value])) return nest.pack_sequence_as(self._nested_structure, flat_ret) + def _from_compatible_tensor_list(self, flat_value): + flat_ret = [] + for sub_value, structure in zip(flat_value, self._flat_nested_structure): + flat_ret.append(structure._from_compatible_tensor_list([sub_value])) + + return nest.pack_sequence_as(self._nested_structure, flat_ret) + @staticmethod def from_value(value): flat_nested_structure = [ @@ -352,6 +378,9 @@ def _from_tensor_list(self, flat_value): if not self.is_compatible_with(Structure.from_value(flat_value[0])): raise ValueError("Cannot convert %r to a tensor with dtype %s and shape " "%s." % (flat_value[0], self._dtype, self._shape)) + return self._from_compatible_tensor_list(flat_value) + + def _from_compatible_tensor_list(self, flat_value): return flat_value[0] @staticmethod @@ -396,6 +425,9 @@ def _from_tensor_list(self, flat_value): not flat_value[0].shape.is_compatible_with(tensor_shape.vector(3))): raise ValueError("SparseTensorStructure corresponds to a single " "tf.variant vector of length 3.") + return self._from_compatible_tensor_list(flat_value) + + def _from_compatible_tensor_list(self, flat_value): return sparse_ops.deserialize_sparse( flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims) diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py index 242215dccb95c3..ffbdff8c47b720 100644 --- a/tensorflow/python/debug/__init__.py +++ b/tensorflow/python/debug/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """Public Python API of TensorFlow Debugger (tfdbg). -See the [TFDBG](https://tensorflow.org/api_guides/python/tfdbg) guide. +See the [TFDBG](https://www.tensorflow.org/guide/debugger) guide. @@add_debug_tensor_watch @@watch_graph diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index afda1fdc0de73b..ae403205b7cc08 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -115,6 +115,8 @@ import re import threading +import six + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_utils @@ -329,6 +331,7 @@ def __init__(self): pass +@six.add_metaclass(abc.ABCMeta) class BaseDebugWrapperSession(session.SessionInterface): """Base class of debug-wrapper session classes. @@ -788,7 +791,6 @@ def close(self): # TODO(cais): Add _node_name_regex_whitelist and # _node_op_type_regex_whitelist. - @abc.abstractmethod def invoke_node_stepper(self, node_stepper, restore_variable_values_on_exit=True): @@ -805,6 +807,9 @@ def invoke_node_stepper(self, The same return values as the `Session.run()` call on the same fetches as the NodeStepper. """ + raise NotImplementedError( + self.__class__.__name__ + " does not support node-stepper mode.") + def should_stop(self): if hasattr(self._sess, "should_stop"): diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 057e609a6551e8..3b4fcd2d977203 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -18,7 +18,7 @@ cc_library( "pywrap_tfe.h", ], visibility = [ - "//learning/deepmind/courier:__pkg__", + "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ @@ -408,10 +408,12 @@ py_library( deps = [ ":context", ":function", + "//tensorflow/python:cond_v2", # TODO(b/118513001): Imported via control_flow_ops; remove. "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:variable_scope", + "//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove. "//tensorflow/python/training/checkpointable:base", ], ) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 5f18ab27b7e350..844c9b52e7fda6 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -574,14 +574,10 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -def _cast_constant(value, dtype): - return math_ops.cast(constant_op.constant(value), dtype) - - def _fast_fill(value, shape, dtype): return array_ops.fill( - _cast_constant(shape, dtype=dtypes.int32), - _cast_constant(value, dtype=dtype)) + constant_op.constant(shape, dtype=dtypes.int32), + constant_op.constant(value, dtype=dtype)) def _zeros(shape, dtype): @@ -599,7 +595,11 @@ def _zeros(shape, dtype): cache_key = shape, dtype, device cached = ctx.zeros_cache().get(cache_key) if cached is None: - cached = _fast_fill(0, shape, dtype) + if dtypes.as_dtype(dtype).is_bool: + value = False + else: + value = 0 + cached = _fast_fill(value, shape, dtype) ctx.zeros_cache().put(cache_key, cached) return cached @@ -608,9 +608,14 @@ def _ones(shape, dtype): if not context.context().executing_eagerly(): return array_ops.ones(shape, dtype) + if dtypes.as_dtype(dtype).is_bool: + value = True + else: + value = 1 + if shape == (): # pylint: disable=g-explicit-bool-comparison - return _cast_constant(1, dtype=dtype) - return _fast_fill(1, shape, dtype) + return constant_op.constant(value, dtype=dtype) + return _fast_fill(value, shape, dtype) _default_vspace = imperative_grad.VSpace( @@ -767,7 +772,10 @@ def _pop_tape(self): def __del__(self): if self._created_eagerly: - context.context().end_step() + try: + context.context().end_step() + except AttributeError: + pass def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 15e47120bf8136..e3fef524bf9f12 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -81,6 +81,57 @@ def flush(self): self._data = {} +class FunctionCallOptions(object): + """Options applied at call sites of eager functions. + Eager functions are functions decorated with tf.contrib.eager.defun. + """ + + def __init__(self, executor_type=None, rewriter_config=None): + """Constructor. + + Args: + executor_type: (optional) name of the executor to be used to execute the + eager function. If None or an empty string, the default Tensorflow + executor will be used. + rewriter_config: (optional) a rewriter_config_pb2.RewriterConfig proto or + a serialized string of that proto. + The config used by Grappler when optimizing the function graph. + Each concrete function is optimized the first time is called. Changing + rewriter_config after the first call has no effect. + If rewriter_config is None, an empty RewriterConfig will be used. + """ + self.rewriter_config_serialized = rewriter_config + self.executor_type = executor_type + + @property + def executor_type(self): + return self._executor_type + + @executor_type.setter + def executor_type(self, executor_type): + self._executor_type = executor_type + + @property + def rewriter_config_serialized(self): + return self._rewriter_config_serialized + + @rewriter_config_serialized.setter + def rewriter_config_serialized(self, config): + if isinstance(config, rewriter_config_pb2.RewriterConfig): + self._rewriter_config_serialized = config.SerializeToString() + elif isinstance(config, str): + self._rewriter_config_serialized = config + elif config is None: + self._rewriter_config_serialized = rewriter_config_pb2.RewriterConfig( + ).SerializeToString() + else: + raise ValueError( + "the rewriter config must be either a " + "rewriter_config_pb2.RewriterConfig, or a serialized string of that " + "proto or None. got: {}" + .format(type(config))) + + # TODO(agarwal): better name ? class _EagerContext(threading.local): """Thread local eager context.""" @@ -99,18 +150,16 @@ def __init__(self, config=None): self.zeros_cache = _EagerTensorCache() self.execution_mode = None - # An empty string corresponds to turning all default grappler optimizations - # on. + # Default rewriter config corresponds to turning all default grappler + # optimizations on. base_config = rewriter_config_pb2.RewriterConfig() - # TODO(b/117959922): Turn this back on once the bug is fixed. - base_config.function_optimization = rewriter_config_pb2.RewriterConfig.OFF - if config is not None and config.HasField( "graph_options") and config.graph_options.HasField("rewrite_options"): base_config.Merge(config.graph_options.rewrite_options) - self.rewriter_config = base_config.SerializeToString() + self.function_call_options = FunctionCallOptions( + rewriter_config=base_config) ContextSwitch = collections.namedtuple( @@ -375,36 +424,6 @@ def _mode(self, mode): if mode == EAGER_MODE: self.context_switches.pop() - @tf_contextlib.contextmanager - def rewriter_config(self, rewriter_config_=None): - """A context manager to allow setting the grappler rewrite options. - - Args: - rewriter_config_: A tensorflow.RewriterConfig proto object. - - Yields: - Nothing. - - Raises: - ValueError: if rewriter_config is not a tensorflow.RewriterConfig proto. - """ - if rewriter_config_ is None or not isinstance( - rewriter_config_, rewriter_config_pb2.RewriterConfig): - raise ValueError("Must pass a rewriter_config proto") - - ctx = self._eager_context - old_rewriter_config = ctx.rewriter_config - ctx.rewriter_config = rewriter_config_.SerializeToString() - try: - yield - finally: - ctx.rewriter_config = old_rewriter_config - - @property - def rewriter_config_string(self): - """Returns the serialized rewriter_config for the current thread.""" - return self._eager_context.rewriter_config - def executing_eagerly(self): """Returns True if current thread has eager executing enabled.""" return self._eager_context.is_eager @@ -533,6 +552,35 @@ def execution_mode(self, mode): finally: self.set_execution_mode(old_mode) + def get_function_call_options(self): + """Returns function call options for current thread. + + Note that the returned object is still referenced by the eager context. + + Returns: the FunctionCallOptions for current thread. + """ + return self._eager_context.function_call_options + + @tf_contextlib.contextmanager + def function_call_options(self, set_options_func): + """Context manager for setting function call options of current thread. + + Args: + set_options_func: A callable that takes one argument of type + FunctionCallOptions. It should set the properties of that + FunctionCallOptions. + + Yields: + Nothing. + """ + current_options = self.get_function_call_options() + old_options = copy.copy(current_options) + try: + set_options_func(current_options) + yield + finally: + self._eager_context.function_call_options = old_options + def async_wait(self): """Waits for ops dispatched in ASYNC mode to finish.""" pywrap_tensorflow.TFE_ContextAsyncWait(self._handle) @@ -785,6 +833,25 @@ def execution_mode(mode): return context().execution_mode(mode) +@tf_export("experimental.function_executor_type") +def function_executor_type(executor_type): + """Context manager for setting the executor of eagar defined functions. + + Eager defined functions are functions decorated by tf.contrib.eager.defun. + + Args: + executor_type: a string for the name of the executor to be used + to execute functions defined by tf.contrib.eager.defun. + + Returns: + Context manager for setting the executor of eager defined functions. + """ + def _set_options_func(options): + options.executor_type = executor_type + + return context().function_call_options(_set_options_func) + + def async_wait(): """Waits for ops dispatched in ASYNC mode to finish.""" return context().async_wait() @@ -830,9 +897,23 @@ def export_run_metadata(): return context().export_run_metadata() -def rewriter_config(rewriter_config_): - """Context manager for setting the grappler rewrite config.""" - return context().rewriter_config(rewriter_config_) +def function_rewriter_config(rewriter_config): + """Context manager for setting the grappler rewrite config. + + This config is used by Grappler when optimizing the function graph. + + Args: + rewriter_config: a rewriter_config_pb2.RewriterConfig proto or + a serialized string of that proto or None. If None, the default instance + of rewriter_config_pb2.RewriterConfig will be used. + + Returns: + A context manager. + """ + def _set_options_func(options): + options.rewriter_config_serialized = rewriter_config + + return context().function_call_options(_set_options_func) def set_server_def(server_def): diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index acc4fecf2d4713..543dcd19ae8450 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -18,12 +18,37 @@ from __future__ import print_function +from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +class _ModelWithOptimizer(training.Model): + + def __init__(self): + super(_ModelWithOptimizer, self).__init__() + self.dense = core.Dense(1) + self.optimizer = adam.AdamOptimizer(0.01) + + @def_function.function( + input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32), + tensor_spec.TensorSpec([None], dtypes.float32))) + def call(self, x, y): + with backprop.GradientTape() as tape: + loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.) + trainable_variables = self.trainable_variables + gradients = tape.gradient(loss, trainable_variables) + self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + return {'loss': loss} class DefFunctionTest(test.TestCase): @@ -164,6 +189,12 @@ def apply(self, x): m1 = MyModel() self.assertAllEqual(m1.apply(3.0), 6.0) + def test_optimizer(self): + x = constant_op.constant([[3., 4.]]) + y = constant_op.constant([2.]) + model = _ModelWithOptimizer() + model(x, y) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 221d52a278f893..08266a115b2e45 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -254,12 +254,14 @@ def call(self, ctx, args): raise ValueError( "Arguments and signature arguments do not match: %s %s " % (len(args), len(list(self.signature.input_arg)))) + function_call_options = ctx.get_function_call_options() outputs = functional_ops.partitioned_call( args=args, f=self, tout=self._output_types, executing_eagerly=executing_eagerly, - config=ctx.rewriter_config_string) # pylint: disable=protected-access + config=function_call_options.rewriter_config_serialized, + executor_type=function_call_options.executor_type) if executing_eagerly: return outputs @@ -537,31 +539,37 @@ def output_dtypes(self): self._func_graph.structured_outputs) def add_to_graph(self, g=None, register_gradient_functions=False): - """Registers the function into the graph g.""" + """Registers the function, adds it to the graph g or default graph.""" + # If we are not executing eagerly, adds the function to default graph if no + # graph is specified. + # In case of eager execution, function definition gets added to context + # during construction itself. + # TODO(allel/shivaniagrawal): rename this to register to reflect the # method's functionality better. Remove register_gradient_functions argument # and figure out if these needs to be registered. - if not g: - g = ops.get_default_graph() - self._inference_function.add_to_graph(g) # pylint: disable=protected-access + if not context.executing_eagerly() or g: + if not g: + g = ops.get_default_graph() + self._inference_function.add_to_graph(g) # pylint: disable=protected-access - # pylint: disable=protected-access - if register_gradient_functions: - # There are two situations for the actual call of a defun: - # 1. If none of the input args are resource variables or watch by any - # tape, and it will run the _inference_function of concrete_func for - # forward pass, the gradient will be generated by standard mechanism. - # 2. Otherwise, defun will create two functions, one for forward pass, and - # the backward pass will be created via tape. - # When registering the function, we register both cases. - if self._backward_graph_function is None: - self._construct_backprop_function() - forward_function = self._forward_function - backward_function = self._backward_graph_function._inference_function - # pylint: enable=protected-access - forward_function.add_to_graph(g) - backward_function.add_to_graph(g) + # pylint: disable=protected-access + if register_gradient_functions: + # There are two situations for the actual call of a defun: + # 1. If none of the input args are resource variables or watch by any + # tape, and it will run the _inference_function of concrete_func for + # forward pass, the gradient will be generated by standard mechanism. + # 2. Otherwise, defun will create two functions, one for forward pass, + # and the backward pass will be created via tape. + # When registering the function, we register both cases. + if self._backward_graph_function is None: + self._construct_backprop_function() + forward_function = self._forward_function + backward_function = self._backward_graph_function._inference_function + # pylint: enable=protected-access + forward_function.add_to_graph(g) + backward_function.add_to_graph(g) def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" @@ -1103,9 +1111,6 @@ def register(func, *args, **kwargs): function definition into graph. Register function with different input param will result into multiple version of functions registered in graph. - Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was - created with an `input_signature`. - Args: func: the PolymorphicFunction instance that generated by a @defun *args: input arguments for the Python function. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 651d6cec7247bc..781c3f0a18ae00 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -105,7 +105,7 @@ def add(x, y): # The default config allows everything. rewrites = rewriter_config_pb2.RewriterConfig() - with context.rewriter_config(rewrites): + with context.function_rewriter_config(rewrites): t = constant_op.constant(1.0) self.assertAllEqual(add(t, t).numpy(), 2.0) @@ -2703,6 +2703,26 @@ def testDecoratedMethodVariableCleanup(self): del m self.assertEqual([], list(weak_variables)) + def testExecutorType(self): + @function.defun + def add_five(x): + return x + 5 + + self.assertEqual( + 5, + add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy()) + + with self.assertRaisesRegexp(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'): + with context.function_executor_type('NON_EXISTENT_EXECUTOR'): + add_five(constant_op.constant(0, dtype=dtypes.int32)) + + for executor_type in ('', 'DEFAULT', None): + with context.function_executor_type(executor_type): + self.assertAllEqual( + 5, + add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy()) + + @parameterized.named_parameters( dict(testcase_name='Defun', function_decorator=function.defun), diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index a2407854fd7483..55f0896e3b4c1b 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -420,9 +420,14 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { if (TF_GetCode(self->status) != TF_OK) { PyErr_SetString( PyExc_TypeError, - tensorflow::strings::StrCat("Error while casting from DataType ", - handle_dtype, " to ", desired_dtype, - ". ", TF_Message(self->status)) + tensorflow::strings::StrCat( + "Error while casting from DataType ", + tensorflow::DataTypeString( + static_cast(handle_dtype)), + " to ", + tensorflow::DataTypeString( + static_cast(desired_dtype)), + ". ", TF_Message(self->status)) .c_str()); // Cleanup self->status before returning. TF_SetStatus(self->status, TF_OK, ""); @@ -435,7 +440,9 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { PyExc_TypeError, tensorflow::strings::StrCat( "Cannot convert value ", TFE_GetPythonString(value_str.get()), - " to EagerTensor with requested dtype: ", desired_dtype) + " to EagerTensor with requested dtype: ", + tensorflow::DataTypeString( + static_cast(desired_dtype))) .c_str()); return -1; } diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index 669fa084888a52..6282a6c4595c96 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -22,6 +22,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import core +from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,6 +35,11 @@ class Tests(test.TestCase): + def setUp(self): + # Force-load `distribution_strategy_context` to prevent GC at + # test time. See discussion in cl//219478951. + tape.distribution_strategy_context.get_distribution_strategy() + @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created def testFastpathExecute_MatMulCorrectResponse(self): diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD deleted file mode 100644 index 60e0e8c8450124..00000000000000 --- a/tensorflow/python/estimator/api/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -package( - default_visibility = [ - "//tensorflow:internal", - ], -) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files") -load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") - -gen_api_init_files( - name = "estimator_python_api_gen", - api_name = "estimator", - output_files = ESTIMATOR_API_INIT_FILES, - output_package = "tensorflow.python.estimator.api", - package_deps = ["//tensorflow/python/estimator:estimator_py"], - packages = ["tensorflow.python.estimator"], -) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 6032b07f69320e..cb0a340c06a811 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -1908,6 +1908,7 @@ def call(self, _): return self._embedding_weight_var +@six.add_metaclass(abc.ABCMeta) class _FeatureColumn(object): """Represents a feature column abstraction. @@ -1923,7 +1924,6 @@ class _FeatureColumn(object): This class is an abstract class. User should not create instances of this. """ - __metaclass__ = abc.ABCMeta @abc.abstractproperty def name(self): @@ -2000,8 +2000,6 @@ class _DenseColumn(_FeatureColumn): indicator_column. """ - __metaclass__ = abc.ABCMeta - @abc.abstractproperty def _variable_shape(self): """`TensorShape` of `_get_dense_tensor`, without batch dimension.""" @@ -2094,7 +2092,6 @@ class _CategoricalColumn(_FeatureColumn): A categorical feature typically handled with a `tf.SparseTensor` of IDs. """ - __metaclass__ = abc.ABCMeta IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name 'IdWeightPair', ['id_tensor', 'weight_tensor']) @@ -2199,8 +2196,6 @@ def _create_categorical_column_weighted_sum(column, class _SequenceDenseColumn(_FeatureColumn): """Represents dense sequence data.""" - __metaclass__ = abc.ABCMeta - TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length']) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index fe079af5470524..d97d41dd830f57 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -1749,6 +1749,7 @@ def crossed_column(keys, hash_bucket_size, hash_key=None): keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key) +@six.add_metaclass(abc.ABCMeta) class FeatureColumn(object): """Represents a feature column abstraction. @@ -1764,7 +1765,6 @@ class FeatureColumn(object): This class is an abstract class. Users should not create instances of this. """ - __metaclass__ = abc.ABCMeta @abc.abstractproperty def name(self): @@ -1847,8 +1847,6 @@ class DenseColumn(FeatureColumn): indicator_column. """ - __metaclass__ = abc.ABCMeta - @abc.abstractproperty def variable_shape(self): """`TensorShape` of `get_dense_tensor`, without batch dimension.""" @@ -1922,7 +1920,6 @@ class CategoricalColumn(FeatureColumn): A categorical feature typically handled with a `tf.SparseTensor` of IDs. """ - __metaclass__ = abc.ABCMeta IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name 'IdWeightPair', ('id_tensor', 'weight_tensor')) @@ -2006,8 +2003,6 @@ def _create_categorical_column_weighted_sum( class SequenceDenseColumn(FeatureColumn): """Represents dense sequence data.""" - __metaclass__ = abc.ABCMeta - TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name 'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length')) diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index e36643b338ed55..48e9f0524e8e8c 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -652,8 +652,10 @@ def size(self): _QUANTIZED_DTYPES_REF = frozenset( [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref]) QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF) -tf_export("dtypes.QUANTIZED_DTYPES", "QUANTIZED_DTYPES").export_constant( - __name__, "QUANTIZED_DTYPES") +tf_export( + "dtypes.QUANTIZED_DTYPES", + v1=["dtypes.QUANTIZED_DTYPES", "QUANTIZED_DTYPES"]).export_constant( + __name__, "QUANTIZED_DTYPES") _PYTHON_TO_TF = { float: float32, diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index b1bb5626f94316..9a3751f4e51f79 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -457,6 +457,24 @@ def convert(x): return func_graph +def maybe_captured(tensor): + """If t is a captured value placeholder, returns the original captured value. + + Args: + tensor: Tensor. + + Returns: + A tensor, potentially from a different Graph/FuncGraph. + """ + if (not isinstance(tensor, ops.EagerTensor) and + tensor.op.graph.building_function and tensor.op.type == "Placeholder"): + for input_t, placeholder_t in tensor.op.graph.captures.items(): + if tensor == placeholder_t: + return maybe_captured(input_t) + # pylint: enable=protected-access + return tensor + + def device_stack_has_callable(device_stack): """Checks whether a device stack contains a callable.""" return any(callable(spec._device_name_or_function) # pylint: disable=protected-access diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 33631282bd03a1..ddf6f66e8ab5e1 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -462,7 +462,7 @@ def _is_default_attr_value(op_def, attr_name, attr_value): return False -def _strip_graph_default_valued_attrs(meta_graph_def): +def strip_graph_default_valued_attrs(meta_graph_def): """Strips default valued attributes for node defs in given MetaGraphDef. This method also sets `meta_info_def.stripped_default_attrs` in the given @@ -587,7 +587,7 @@ def create_meta_graph_def(meta_info_def=None, # Strip default valued attributes in graph_def. if strip_default_attrs: - _strip_graph_default_valued_attrs(meta_graph_def) + strip_graph_default_valued_attrs(meta_graph_def) # Adds saver_def. if saver_def: diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ad8ba4b2be5f15..14e4c0ca41aa2b 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4960,6 +4960,8 @@ def container(container_name): def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): if context.executing_eagerly(): if op is not None: + if not hasattr(op, "device"): + op = internal_convert_to_tensor_or_indexed_slices(op) return device(op.device) else: return NullContextmanager() @@ -4975,7 +4977,10 @@ def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): op, gradient_uid=gradient_uid, ignore_existing=ignore_existing) -@tf_export("colocate_with") +@deprecation.deprecated( + date=None, + instructions="Colocations handled automatically by placer.") +@tf_export(v1=["colocate_with"]) def colocate_with(op, ignore_existing=False): return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing) @@ -5360,6 +5365,16 @@ def func(): outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access +def executing_eagerly_outside_functions(): + """Returns True if executing eagerly, even if inside a graph function.""" + with init_scope(): + return context.executing_eagerly() + + +def inside_function(): + return get_default_graph().building_function + + @tf_export("enable_eager_execution") def enable_eager_execution(config=None, device_policy=None, diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index c945f026390191..0fb17081e758a7 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -2292,6 +2292,19 @@ def foo(): foo_compiled() self.assertEqual(ops.get_name_scope(), "") + def testExecutingEagerlyOutsideFunctions(self): + + @eager_function.defun + def f(): + return ops.executing_eagerly_outside_functions() + + with context.eager_mode(): + self.assertTrue(ops.executing_eagerly_outside_functions()) + self.assertTrue(f()) + g = ops.Graph() + with g.as_default(): + self.assertFalse(ops.executing_eagerly_outside_functions()) + class GraphTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index f6aef5bc50b57a..65b9ad5c6a2b51 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -45,6 +45,9 @@ namespace tensorflow { namespace python_op_gen_internal { const int kRightMargin = 78; +// Names specified in tf_export decorators are exported to +// TensorFlow 2.0 by default. +const int kLatestAPIExportVersion = 2; bool IsPythonReserved(const string& s) { static const std::set* const kPythonReserved = new std::set( @@ -585,28 +588,42 @@ void GenPythonOp::AddExport() { if (api_def_.visibility() != ApiDef::VISIBLE) { return; } + // Whether op should be available in latest export version. + bool op_available_in_latest = + !api_def_.deprecation_version() || + api_def_.deprecation_version() > kLatestAPIExportVersion; - // Add @tf_export decorator. - strings::StrAppend(&result_, "@tf_export("); + string names; + string names_v1; + string deprecated_endpoints; - // Add all endpoint names to tf_export. - bool first_endpoint = true; - std::vector deprecated_endpoints; for (const auto& endpoint : api_def_.endpoint()) { - if (!first_endpoint) { - strings::StrAppend(&result_, ", "); - } else { - first_endpoint = false; - } string endpoint_name; python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(), &endpoint_name); - if (endpoint.deprecated()) { - deprecated_endpoints.push_back(endpoint_name); + if (endpoint.deprecated() || endpoint.deprecation_version() > 0) { + AddDelimiter(&deprecated_endpoints, ", "); + strings::StrAppend(&deprecated_endpoints, "'", endpoint_name, "'"); + } + // Add all endpoints to TensorFlow 1.* API. + AddDelimiter(&names_v1, ", "); + strings::StrAppend(&names_v1, "'", endpoint_name, "'"); + // Add non-deprecated endpoints to TensorFlow 2.* API. + if (op_available_in_latest && + (!endpoint.deprecation_version() || + endpoint.deprecation_version() > kLatestAPIExportVersion)) { + AddDelimiter(&names, ", "); + strings::StrAppend(&names, "'", endpoint_name, "'"); } - strings::StrAppend(&result_, "'", endpoint_name, "'"); } - strings::StrAppend(&result_, ")\n"); + + // tf_export decorator has the following format: + // @tf_export(v2_name, v2_name, v1=[v1_name, v1_name]) + if (names != names_v1) { + AddDelimiter(&names, ", "); + strings::StrAppend(&names, "v1=[", names_v1, "]"); + } + strings::StrAppend(&result_, "@tf_export(", names, ")\n"); // If all endpoints are deprecated, add @deprecated decorator. if (!api_def_.deprecation_message().empty()) { @@ -615,17 +632,8 @@ void GenPythonOp::AddExport() { } // Add @deprecated_endpoints decorator. if (!deprecated_endpoints.empty()) { - strings::StrAppend(&result_, "@deprecated_endpoints("); - bool first_endpoint = true; - for (auto& endpoint_name : deprecated_endpoints) { - if (first_endpoint) { - first_endpoint = false; - } else { - strings::StrAppend(&result_, ", "); - } - strings::StrAppend(&result_, "'", endpoint_name, "'"); - } - strings::StrAppend(&result_, ")\n"); + strings::StrAppend(&result_, "@deprecated_endpoints(", deprecated_endpoints, + ")\n"); } } diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index a3c7bd2db29ff9..5a58d271488080 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -18,6 +18,7 @@ from __future__ import print_function from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.python import tf2 from tensorflow.python.framework import dtypes from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -1157,7 +1158,10 @@ def _v2_behavior(self): return _TENSORSHAPE_V2_OVERRIDE -TensorShape = TensorShapeV1 +if tf2.enabled(): + TensorShape = TensorShapeV2 +else: + TensorShape = TensorShapeV1 def scalar(): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 1d9eacdcdbef51..768ed36917ffec 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -61,6 +61,7 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops @@ -764,7 +765,8 @@ def run_all_in_graph_and_eager_modes(cls): """Execute all test methods in the given class with and without eager.""" base_decorator = run_in_graph_and_eager_modes for name, value in cls.__dict__.copy().items(): - if callable(value) and name.startswith("test"): + if callable(value) and name.startswith( + "test") and not name.startswith("testSkipEager"): setattr(cls, name, base_decorator(value)) return cls @@ -832,7 +834,7 @@ def decorator(f): if tf_inspect.isclass(f): raise ValueError( "`run_test_in_graph_and_eager_modes` only supports test methods. " - "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?") + "Did you mean to use `run_all_in_graph_and_eager_modes`?") def decorated(self, **kwargs): try: @@ -1134,6 +1136,9 @@ def _eval_tensor(self, tensor): return self._eval_helper(tensor()) else: try: + if sparse_tensor.is_sparse(tensor): + return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values, + tensor.dense_shape) return tensor.numpy() except AttributeError as e: six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) @@ -1669,9 +1674,16 @@ def assertAllEqual(self, a, b, msg=None): msg = msg if msg else "" a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual( - a.shape, b.shape, "Shape mismatch: expected %s, got %s." - " %s" % (a.shape, b.shape, msg)) + # Arbitrary bounds so that we don't print giant tensors. + if (b.ndim <= 3 or b.size < 500): + self.assertEqual( + a.shape, b.shape, "Shape mismatch: expected %s, got %s." + " Contents: %s. \n%s." % (a.shape, b.shape, b, msg)) + else: + self.assertEqual( + a.shape, b.shape, "Shape mismatch: expected %s, got %s." + " %s" % (a.shape, b.shape, msg)) + same = (a == b) if (a.dtype in [ diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py index 7f4d28237ffba8..c4e35a83256c2d 100644 --- a/tensorflow/python/framework/traceable_stack.py +++ b/tensorflow/python/framework/traceable_stack.py @@ -58,7 +58,7 @@ def set_filename_and_line_from_caller(self, offset=0): frame_records = tf_stack.extract_stack() if not frame_records: return self.FAILURE - if len(frame_records) >= local_offset: + if len(frame_records) > local_offset: # Negative indexing is one-indexed instead of zero-indexed. negative_offset = -(local_offset + 1) self.filename, self.lineno = frame_records[negative_offset][:2] diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index d827fe3c30b437..9bdef21234a973 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -55,19 +55,26 @@ py_library( ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = select({ - ":empty_condition": [], - "//conditions:default": [], - }) + [ + deps = [ ":backend", ":engine", ":layers", - "//tensorflow/python/keras/optimizer_v2:optimizer_v2", - "//tensorflow/python/saved_model", + ":pil_for_keras", "//tensorflow/python:training", + "//tensorflow/python/keras/optimizer_v2", + "//tensorflow/python/saved_model", + "@keras_applications_archive//:keras_applications", ], ) +py_library( + name = "pil_for_keras", + deps = select({ + ":empty_condition": [], + "//conditions:default": [], + }), +) + py_library( name = "backend", srcs = ["backend.py"], @@ -356,15 +363,12 @@ cuda_py_test( "//tensorflow/python:client_testlib", ], shard_count = 2, - tags = [ - "no_oss", # b/117834718 - "no_windows_gpu", - ], + tags = ["no_windows_gpu"], ) py_test( name = "pooling_test", - size = "small", + size = "medium", srcs = ["layers/pooling_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b3beccd82bed72..b1999d9566b53c 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -2783,9 +2783,14 @@ def get_value(x): Returns: A Numpy array. + + Raises: + RuntimeError: If this method is called inside defun. """ if context.executing_eagerly(): return x.numpy() + elif ops.inside_function(): + raise RuntimeError('Cannot get value inside Tensorflow graph function.') return x.eval(session=get_session()) @@ -2798,9 +2803,14 @@ def batch_get_value(tensors): Returns: A list of Numpy arrays. + + Raises: + RuntimeError: If this method is called inside defun. """ if context.executing_eagerly(): return [x.numpy() for x in tensors] + elif ops.inside_function(): # pylint: disable=protected-access + raise RuntimeError('Cannot get value inside Tensorflow graph function.') if tensors: return get_session().run(tensors) else: @@ -2817,7 +2827,7 @@ def set_value(x, value): (of the same shape). """ value = np.asarray(value, dtype=dtype(x)) - if context.executing_eagerly(): + if ops.executing_eagerly_outside_functions(): x.assign(value) else: with get_graph().as_default(): @@ -2841,7 +2851,7 @@ def batch_set_value(tuples): tuples: a list of tuples `(tensor, value)`. `value` should be a Numpy array. """ - if context.executing_eagerly(): + if ops.executing_eagerly_outside_functions(): for x, value in tuples: x.assign(np.asarray(value, dtype=dtype(x))) else: diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 4c12c83a4c2cc0..4bdab56eb4416b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -96,8 +96,8 @@ def configure_callbacks(callbacks, # Add additional callbacks model.history = History() stateful_metric_names = None - if hasattr(model, 'stateful_metric_names'): - stateful_metric_names = model.stateful_metric_names + if hasattr(model, 'metrics_names'): + stateful_metric_names = model.metrics_names[1:] # Exclude `loss` callbacks = [BaseLogger(stateful_metrics=stateful_metric_names) ] + (callbacks or []) + [model.history] if verbose: @@ -108,10 +108,10 @@ def configure_callbacks(callbacks, # Set callback model callback_model = model._get_callback_model() # pylint: disable=protected-access if do_validation and val_inputs and not context.executing_eagerly(): - # Need to create the test_function before start of the first epoch + # Need to create the eval_function before start of the first epoch # because TensorBoard callback on_epoch_begin adds summary to the - # list of fetches of the test_function - callback_model._make_test_function() # pylint: disable=protected-access + # list of fetches of the eval_function + callback_model._make_eval_function() # pylint: disable=protected-access callback_list.set_model(callback_model) # Set callback parameters @@ -1124,17 +1124,19 @@ def on_batch_end(self, batch, logs=None): self._total_batches_seen += 1 def on_epoch_begin(self, epoch, logs=None): - """Add histogram op to Model test_function callbacks, reset batch count.""" + """Add histogram op to Model eval_function callbacks, reset batch count.""" # check if histogram summary should be run for this epoch if self.histogram_freq and epoch % self.histogram_freq == 0: self._epoch = epoch self._current_val_batch = 0 + # pylint: disable=protected-access # add the histogram summary op if it should run this epoch - if self.merged not in self.model.test_function.fetches: - self.model.test_function.fetches.append(self.merged) - self.model.test_function.fetch_callbacks[ + if self.merged not in self.model._eval_function.fetches: + self.model._eval_function.fetches.append(self.merged) + self.model._eval_function.fetch_callbacks[ self.merged] = self._fetch_callback + # pylint: enable=protected-access def on_epoch_end(self, epoch, logs=None): """Checks if summary ops should run next epoch, logs scalar summaries.""" @@ -1152,10 +1154,12 @@ def on_epoch_end(self, epoch, logs=None): # pop the histogram summary op after each epoch if self.histogram_freq: - if self.merged in self.model.test_function.fetches: - self.model.test_function.fetches.remove(self.merged) - if self.merged in self.model.test_function.fetch_callbacks: - self.model.test_function.fetch_callbacks.pop(self.merged) + # pylint: disable=protected-access + if self.merged in self.model._eval_function.fetches: + self.model._eval_function.fetches.remove(self.merged) + if self.merged in self.model._eval_function.fetch_callbacks: + self.model._eval_function.fetch_callbacks.pop(self.merged) + # pylint: enable=protected-access if self.embeddings_data is None and self.embeddings_freq: raise ValueError('To visualize embeddings, embeddings_data must ' diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index bb85347033cdee..22efa7a378113a 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -422,8 +422,7 @@ def make_model(): num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM) model.compile( loss='categorical_crossentropy', - optimizer=keras.optimizers.SGD(lr=0.1), - metrics=['accuracy']) + optimizer=keras.optimizers.SGD(lr=0.1)) return model model = make_model() diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index c25702d964ed52..5ce4ca4df41c7a 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -357,6 +357,12 @@ def from_config(cls, config, custom_objects=None): model.built = False return model + @property + def input_spec(self): + if self.layers and hasattr(self.layers[0], 'input_spec'): + return self.layers[0].input_spec + return None + def get_input_shape_and_dtype(layer): """Retrieve input shape and input dtype of layer if applicable. diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 401dff308ad121..1401c1ed996aa0 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -23,6 +23,7 @@ from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import testing_utils @@ -317,6 +318,15 @@ def test_variable_names(self): 'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'], [v.name for v in model.variables]) + @tf_test_util.run_in_graph_and_eager_modes + def test_input_assumptions_propagation(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1)) + if context.executing_eagerly(): + with self.assertRaisesRegexp(ValueError, + 'expected min_ndim=2, found ndim=0'): + model(1.0) + class TestSequentialEagerIntegration(test.TestCase): diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 7536a9a6e72f94..1847a6a38979d3 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -31,7 +31,6 @@ from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers -from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.engine import training_arrays from tensorflow.python.keras.engine import training_distributed @@ -41,8 +40,6 @@ from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils.generic_utils import slice_arrays -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training.checkpointable import base as checkpointable @@ -188,10 +185,16 @@ def _add_unique_metric_name(self, metric_name, output_index): def _init_metric_attributes(self): """Initialized model metric attributes.""" + # List of all metric names in the model. self.metrics_names = ['loss'] + # List of all aggregated metric result tensors. This includes aggregated + # loss result tensors. + self._stateful_metrics_tensors = [] + # List of all metric result tensors (aggregated or not - based on the + # values given in compile.) self.metrics_tensors = [] - self.metrics_updates = [] - self.stateful_metric_names = [] + # List of stateful metric functions. Used for resetting metric state during + # training/eval. This includes loss functions. self.stateful_metric_functions = [] def _set_per_output_metric_attributes(self, metrics_dict, output_index): @@ -202,15 +205,13 @@ def _set_per_output_metric_attributes(self, metrics_dict, output_index): output_index: The index of the model output for which the metric attributes are added. """ - for metric_name, metric_fn in metrics_dict.items(): + for metric_name, (_, stateful_metric_fn) in metrics_dict.items(): metric_name = self._add_unique_metric_name(metric_name, output_index) # Keep track of metric name. self.metrics_names.append(metric_name) - # Keep track of stateful metric attributes (name and metric function). - if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: - self.stateful_metric_names.append(metric_name) - self.stateful_metric_functions.append(metric_fn) + # Keep track of stateful metric function. + self.stateful_metric_functions.append(stateful_metric_fn) def _set_metric_attributes(self, outputs, skip_target_indices=None): """Sets the metric attributes on the model for all the model outputs.""" @@ -227,7 +228,8 @@ def _handle_per_output_metrics(self, y_true, y_pred, mask, - weights=None): + weights=None, + return_stateful_result=True): """Calls metric functions for a single output. Arguments: @@ -236,52 +238,49 @@ def _handle_per_output_metrics(self, y_pred: Predicted output. mask: Computed mask value for the current output. weights: Weights to be applied on the current output. + return_stateful_result: Boolean, indicates whether the stateful + (aggregated)/stateless metric result should be returned. Returns: A list of metric result tensors. """ metric_results = [] - for metric_name, metric_fn in metrics_dict.items(): + for metric_name, (metric_fn, stateful_fn) in metrics_dict.items(): with K.name_scope(metric_name): + + def _call_stateful_fn(fn): + return training_utils.call_metric_function( + fn, y_true, y_pred, weights=weights, mask=mask) + + def _call_stateless_fn(fn): + weighted_metric_fn = training_utils.weighted_masked_objective(fn) + return weighted_metric_fn(y_true, y_pred, weights=weights, mask=mask) + + def _track_metric_tensors(stateless_result, stateful_result): + self.metrics_tensors.append(stateless_result) + self._stateful_metrics_tensors.append(stateful_result) + if isinstance(metric_fn, metrics_module.Metric): - # Call the stateful metric function. - if mask is not None: - mask = math_ops.cast(mask, y_pred.dtype) - # Update weights with mask. - if weights is None: - weights = mask - else: - # Update shape of weights if possible before adding mask. - # Update dimensions of weights to match with mask if possible. - mask, _, weights = metrics_module.squeeze_or_expand_dimensions( - mask, None, weights) - try: - # Broadcast weights if possible. - weights = weights_broadcast_ops.broadcast_weights(weights, mask) - except ValueError: - pass - # TODO(psv): Handle case when mask and weight shapes are not - # compatible. - weights *= mask - - metric_result = metric_fn(y_true, y_pred, weights) + # If the given metric fn is stateful, call the fn and return result. + metric_result = _call_stateful_fn(metric_fn) + metric_results.append(metric_result) + if not context.executing_eagerly(): + _track_metric_tensors(metric_result, metric_result) + elif context.executing_eagerly(): + # In eager mode, if the given metric fn is not stateful, we invoke the + # given fn or its stateful version based on the given flag. + if return_stateful_result: + metric_result = _call_stateful_fn(stateful_fn) + else: + metric_result = _call_stateless_fn(metric_fn) + metric_results.append(metric_result) else: - # Call the stateless metric function. - weighted_metric_fn = training_utils.weighted_masked_objective( - metric_fn) - metric_result = weighted_metric_fn( - y_true, y_pred, weights=weights, mask=mask) - - if not context.executing_eagerly(): - # Keep track of metric result tensor. - self.metrics_tensors.append(metric_result) - - metric_results.append(metric_result) - is_stateful = isinstance(metric_fn, - base_layer.Layer) and metric_fn.stateful - if is_stateful and not context.executing_eagerly(): - # Keep track of updates created by stateful metrics. - self.metrics_updates += metric_fn.updates + # In graph mode, we build the sub-graph for both the stateful and the + # stateless fns. + stateful_metric_result = _call_stateful_fn(stateful_fn) + metric_result = _call_stateless_fn(metric_fn) + _track_metric_tensors(metric_result, stateful_metric_result) + return metric_results def _handle_metrics(self, @@ -289,7 +288,8 @@ def _handle_metrics(self, skip_target_indices=None, targets=None, sample_weights=None, - masks=None): + masks=None, + return_stateful_result=True): """Handles calling metric functions. Arguments: @@ -298,6 +298,8 @@ def _handle_metrics(self, targets: List of targets. sample_weights: Optional list of sample weight arrays. masks: List of computed output mask values. + return_stateful_result: Boolean, indicates whether the stateful + (aggregated)/stateless metric result should be returned. Returns: A list of metric result tensors. @@ -312,15 +314,20 @@ def _handle_metrics(self, target = targets[i] if targets else None output_mask = masks[i] if masks else None metric_results.extend( - self._handle_per_output_metrics(self._per_output_metrics[i], target, - output, output_mask)) + self._handle_per_output_metrics( + self._per_output_metrics[i], + target, + output, + output_mask, + return_stateful_result=return_stateful_result)) metric_results.extend( self._handle_per_output_metrics( self._per_output_weighted_metrics[i], target, output, output_mask, - weights=sample_weights[i])) + weights=sample_weights[i], + return_stateful_result=return_stateful_result)) return metric_results @checkpointable.no_automatic_dependency_tracking @@ -474,16 +481,14 @@ def compile(self, loss_functions = [loss_function for _ in range(len(self.outputs))] self.loss_functions = loss_functions - weighted_losses = [training_utils.weighted_masked_objective(fn) - for fn in loss_functions] skip_target_indices = [] skip_target_weighing_indices = [] self._feed_outputs = [] self._feed_output_names = [] self._feed_output_shapes = [] self._feed_loss_fns = [] - for i in range(len(weighted_losses)): - if weighted_losses[i] is None: + for i in range(len(loss_functions)): + if loss_functions[i] is None: skip_target_indices.append(i) skip_target_weighing_indices.append(i) @@ -618,14 +623,30 @@ def compile(self, continue y_true = self.targets[i] y_pred = self.outputs[i] - weighted_loss = weighted_losses[i] + loss_fn = loss_functions[i] sample_weight = self.sample_weights[i] mask = masks[i] loss_weight = loss_weights_list[i] with K.name_scope(self.output_names[i] + '_loss'): + weighted_loss = training_utils.weighted_masked_objective(loss_fn) output_loss = weighted_loss(y_true, y_pred, sample_weight, mask) + if len(self.outputs) > 1: + # Keep track of the un-aggregated loss result tensor. self.metrics_tensors.append(output_loss) + + # Keep track of stateful result tensor and function for the loss. + mean_wrapped_loss = metrics_module.MeanMetricWrapper( + loss_fn, name=loss_fn.__name__) + result_tensor = training_utils.call_metric_function( + mean_wrapped_loss, + y_true, + y_pred, + weights=sample_weight, + mask=mask) + self._stateful_metrics_tensors.append(result_tensor) + self.stateful_metric_functions.append(mean_wrapped_loss) + self.metrics_names.append(self.output_names[i] + '_loss') if total_loss is None: total_loss = loss_weight * output_loss @@ -664,6 +685,8 @@ def compile(self, # This saves time when the user is not using all functions. self._function_kwargs = kwargs + self._fit_function = None + self._eval_function = None self.train_function = None self.test_function = None self.predict_function = None @@ -690,11 +713,11 @@ def _check_trainable_weights_consistency(self): ' trainable weights, did you set `model.trainable`' ' without calling `model.compile` after ?', 1) - def _make_train_function(self): - if not hasattr(self, 'train_function'): + def _make_train_function_helper(self, fn_name, outputs, metric_updates=None): + if not hasattr(self, fn_name): raise RuntimeError('You must compile your model before using it.') self._check_trainable_weights_consistency() - if self.train_function is None: + if getattr(self, fn_name) is None: inputs = (self._feed_inputs + self._feed_targets + self._feed_sample_weights) @@ -710,31 +733,62 @@ def _make_train_function(self): updates += self.get_updates_for(None) # Conditional updates relevant to this model updates += self.get_updates_for(self.inputs) - # Stateful metrics updates - updates += self.metrics_updates + # Add stateful metrics updates. + if metric_updates is not None: + updates += metric_updates # Gets loss and metrics. Updates weights at each call. - self.train_function = K.function( - inputs, [self.total_loss] + self.metrics_tensors, + fn = K.function( + inputs, + outputs, updates=updates, name='train_function', **self._function_kwargs) + setattr(self, fn_name, fn) - def _make_test_function(self): - if not hasattr(self, 'test_function'): + def _make_train_function(self): + self._make_train_function_helper('train_function', + [self.total_loss] + self.metrics_tensors) + + def _make_fit_function(self): + # TODO(psv/anjalisridhar): Remove updates after we fix b/118841692 + # Stateful metrics updates + metric_updates = [] + for m in self.stateful_metric_functions: + metric_updates += m.updates + self._make_train_function_helper( + '_fit_function', [self.total_loss] + self._stateful_metrics_tensors, + metric_updates) + + def _make_test_function_helper(self, fn_name, outputs, metric_updates=None): + if not hasattr(self, fn_name): raise RuntimeError('You must compile your model before using it.') - if self.test_function is None: + if getattr(self, fn_name) is None: inputs = (self._feed_inputs + self._feed_targets + self._feed_sample_weights) if self.uses_learning_phase and not isinstance(K.learning_phase(), int): inputs += [K.learning_phase()] + updates = self.state_updates + # Add stateful metrics updates. + if metric_updates is not None: + updates += metric_updates # Return loss and metrics, no gradient updates. # Does update the network states. - self.test_function = K.function( - inputs, [self.total_loss] + self.metrics_tensors, - updates=self.state_updates + self.metrics_updates, + fn = K.function( + inputs, + outputs, + updates=updates, name='test_function', **self._function_kwargs) + setattr(self, fn_name, fn) + + def _make_test_function(self): + self._make_test_function_helper('test_function', + [self.total_loss] + self.metrics_tensors) + + def _make_eval_function(self): + self._make_test_function_helper( + '_eval_function', [self.total_loss] + self._stateful_metrics_tensors) def _make_predict_function(self): if not hasattr(self, 'predict_function'): @@ -770,7 +824,8 @@ def _distribution_standardize_user_data(self, check_steps=False, steps_name='steps', steps=None, - validation_split=0): + validation_split=0, + shuffle=False): """Runs validation checks on input and target data passed by the user. This is called when using DistributionStrategy to train, evaluate or serve @@ -793,6 +848,7 @@ def _distribution_standardize_user_data(self, execute. validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. + shuffle: Boolean whether to shuffle the training data before each epoch. Returns: Iterator for reading the dataset `x`. @@ -845,11 +901,12 @@ def _distribution_standardize_user_data(self, x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y)) x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y)) - # 1024 is a good buffer size since it is much larger than the average - # batch size provided by the user and provides sufficient randomness. - # One thing to keep in mind is the memory usage based on the size of - # each sample. - x = x.shuffle(1024) + if shuffle: + # 1024 is a good buffer size since it is much larger than the average + # batch size provided by the user and provides sufficient randomness. + # One thing to keep in mind is the memory usage based on the size of + # each sample. + x = x.shuffle(1024) x = x.repeat() x = x.batch(batch_size, drop_remainder=drop_remainder) y = None @@ -887,7 +944,8 @@ def _standardize_user_data(self, check_steps=False, steps_name='steps', steps=None, - validation_split=0): + validation_split=0, + shuffle=False): """Runs validation checks on input and target data passed by the user. Also standardizes the data to lists of arrays, in order. @@ -929,6 +987,7 @@ def _standardize_user_data(self, execute. validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. + shuffle: Boolean whether to shuffle the training data before each epoch. Returns: A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict @@ -951,7 +1010,8 @@ def _standardize_user_data(self, check_steps=check_steps, steps_name=steps_name, steps=steps, - validation_split=validation_split) + validation_split=validation_split, + shuffle=shuffle) return iterator, None, None if isinstance(x, dataset_ops.Dataset): @@ -1489,7 +1549,8 @@ def fit(self, check_steps=True, steps_name='steps_per_epoch', steps=steps_per_epoch, - validation_split=validation_split) + validation_split=validation_split, + shuffle=shuffle) # Prepare validation data. if validation_data: @@ -1893,7 +1954,7 @@ class indices (integers) to ins = x + y + sample_weights self._make_train_function() - outputs = self.train_function(ins) + outputs = self.train_function(ins) # pylint: disable=not-callable if len(outputs) == 1: return outputs[0] @@ -1951,7 +2012,7 @@ def test_on_batch(self, x, y=None, sample_weight=None): else: ins = x + y + sample_weights self._make_test_function() - outputs = self.test_function(ins) + outputs = self.test_function(ins) # pylint: disable=not-callable if len(outputs) == 1: return outputs[0] diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index 95b864bef028ec..bad1cb40f24505 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -83,8 +83,8 @@ def fit_loop(model, Raises: ValueError: in case of invalid arguments. """ - model._make_train_function() - f = model.train_function + model._make_fit_function() + f = model._fit_function sample_weights = sample_weights or [] val_sample_weights = val_sample_weights or [] @@ -366,8 +366,8 @@ def test_loop(model, and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ - model._make_test_function() - f = model.test_function + model._make_eval_function() + f = model._eval_function sample_weights = sample_weights or [] inputs = training_utils.ModelInputs(inputs).as_list() @@ -379,12 +379,6 @@ def test_loop(model, if hasattr(model, 'metrics'): for m in model.stateful_metric_functions: m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(model.metrics_names) - if str(name) in model.stateful_metric_names - ] - else: - stateful_metric_indices = [] num_samples = training_utils.check_num_samples( ins, batch_size, steps, 'steps') @@ -409,20 +403,15 @@ def test_loop(model, if step == 0: for _ in enumerate(batch_outs): outs.append(0.) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out + outs[0] += batch_outs[0] # index 0 = 'loss' + outs[1:] = batch_outs[1:] else: if step == 0: outs.append(0.) outs[0] += batch_outs if verbose == 1: progbar.update(step + 1) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= steps + outs[0] /= steps else: batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) @@ -441,20 +430,15 @@ def test_loop(model, if isinstance(batch_outs, list): if batch_index == 0: outs.extend([0.] * len(batch_outs)) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out * len(batch_ids) + outs[0] += batch_outs[0] * len(batch_ids) # index 0 = 'loss' + outs[1:] = batch_outs[1:] else: if batch_index == 0: outs.append(0.) outs[0] += batch_outs * len(batch_ids) if verbose == 1: progbar.update(batch_end) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= num_samples + outs[0] /= num_samples if len(outs) == 1: return outs[0] return outs diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 418bebccb091b8..8550b960557f62 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -93,20 +93,18 @@ def fit_loop( if not model._grouped_model: clone_model_on_replicas(model, current_strategy, make_callback_model=True) - def _per_device_train_function(model): - model._make_train_function() - return (model.train_function.inputs, - model.train_function.outputs, - model.train_function.updates_op, - model.train_function.session_kwargs) + def _per_device_fit_function(model): + model._make_fit_function() + return (model._fit_function.inputs, model._fit_function.outputs, + model._fit_function.updates_op, model._fit_function.session_kwargs) inputs, targets, sample_weights = _get_input_from_iterator(iterator, model) with current_strategy.scope(): # Create train ops on each of the devices when we call - # `_per_device_train_function`. + # `_per_device_fit_function`. (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( - _per_device_train_function, model._grouped_model) + _per_device_fit_function, model._grouped_model) # Unwrap all the per device values returned from `call_for_each_replica`. # Unwrapping per device values gives you a list of values that can be # used to construct a new train function that is composed of update ops on @@ -124,10 +122,11 @@ def _per_device_train_function(model): current_strategy, targets) # Create a train function that is composed of all the parameters above. - distributed_train_function = K.Function( - all_inputs, all_outputs, + distributed_fit_function = K.Function( + all_inputs, + all_outputs, updates=all_updates, - name='distributed_train_function', + name='distributed_fit_function', **all_session_args) # We need to set sample_weights to None since there are sample weight @@ -173,7 +172,7 @@ def _per_device_train_function(model): batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) try: - outs = distributed_train_function(ins) + outs = distributed_fit_function(ins) except errors.OutOfRangeError: logging.warning('Your dataset iterator ran out of data; ' 'interrupting training. Make sure that your dataset ' @@ -184,11 +183,6 @@ def _per_device_train_function(model): if not isinstance(outs, list): outs = [outs] - - outs = _aggregate_metrics_across_replicas(current_strategy.num_replicas, - out_labels, - model.stateful_metric_names, - outs) for l, o in zip(out_labels, outs): batch_logs[l] = o callbacks.on_batch_end(step_index, batch_logs) @@ -256,19 +250,17 @@ def _experimental_fit_loop( K.get_session().run(current_strategy.initialize()) - def _per_device_train_function(model): - model._make_train_function() - return (model.train_function.inputs, - model.train_function.outputs, - model.train_function.updates_op, - model.train_function.session_kwargs) + def _per_device_fit_function(model): + model._make_fit_function() + return (model._fit_function.inputs, model._fit_function.outputs, + model._fit_function.updates_op, model._fit_function.session_kwargs) # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. K.set_learning_phase(1) out_labels = model.metrics_names or [] def step_fn(ctx, inputs, targets): - """Clones the model and calls make_train_function.""" + """Clones the model and calls make_fit_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. @@ -282,15 +274,16 @@ def step_fn(ctx, inputs, targets): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( - _per_device_train_function, model._grouped_model_train) + _per_device_fit_function, model._grouped_model_train) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.Function( - all_inputs, all_outputs, + all_inputs, + all_outputs, updates=all_updates, - name='distributed_train_function', + name='distributed_fit_function', **all_session_args) for label, output in zip(out_labels, combined_fn.outputs): @@ -444,18 +437,17 @@ def test_loop(model, iterator, verbose=0, steps=None): if not model._grouped_model: clone_model_on_replicas(model, current_strategy) - def _per_device_test_function(model): - model._make_test_function() - return (model.test_function.inputs, - model.test_function.outputs, - model.test_function.updates_op, - model.test_function.session_kwargs) + def _per_device_eval_function(model): + model._make_eval_function() + return (model._eval_function.inputs, model._eval_function.outputs, + model._eval_function.updates_op, + model._eval_function.session_kwargs) inputs, targets, sample_weights = _get_input_from_iterator(iterator, model) with current_strategy.scope(): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( - _per_device_test_function, model._grouped_model) + _per_device_eval_function, model._grouped_model) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( @@ -484,10 +476,6 @@ def _per_device_test_function(model): for m in model.stateful_metric_functions: m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(model.metrics_names) - if str(name) in model.stateful_metric_names - ] outs = [] if verbose == 1: @@ -502,26 +490,18 @@ def _per_device_test_function(model): assert steps is not None for step in range(steps): batch_outs = distributed_test_function(ins) - batch_outs = _aggregate_metrics_across_replicas( - current_strategy.num_replicas, model.metrics_names, - model.stateful_metric_names, batch_outs) if isinstance(batch_outs, list): if step == 0: outs = [0.] * len(batch_outs) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out + outs[0] += batch_outs[0] # index 0 = 'loss' + outs[1:] = batch_outs[1:] else: if step == 0: outs.append(0.) - outs[0] += batch_outs + outs[0] += batch_outs # index 0 = 'loss' if verbose >= 1: progbar.update(step + 1) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= steps + outs[0] /= steps # index 0 = 'loss' if len(outs) == 1: return outs[0] @@ -552,18 +532,17 @@ def _experimental_test_loop(model, iterator, verbose=0, steps=None, if initialize_finalize_strategy: K.get_session().run(current_strategy.initialize()) - def _per_device_test_function(model): - model._make_test_function() - return (model.test_function.inputs, - model.test_function.outputs, - model.test_function.updates_op, - model.test_function.session_kwargs) + def _per_device_eval_function(model): + model._make_eval_function() + return (model._eval_function.inputs, model._eval_function.outputs, + model._eval_function.updates_op, + model._eval_function.session_kwargs) # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. K.set_learning_phase(0) def step_fn(ctx, inputs, targets): - """Clones the model and calls make_test_function.""" + """Clones the model and calls make_eval_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. @@ -577,7 +556,7 @@ def step_fn(ctx, inputs, targets): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( - _per_device_test_function, model._grouped_model_test) + _per_device_eval_function, model._grouped_model_test) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( @@ -911,45 +890,6 @@ def clone_model_on_replicas(model, strategy, make_callback_model=False, model._make_callback_model(grouped_model) -def _aggregate_metrics_across_replicas(num_devices, out_labels, - stateful_metric_names, outs): - """Aggregates stateless metrics values across replicas. - - When using `MirroredStrategy`, the number of replicas is equal to the - number of devices over which training is distributed. This may not always be - the case. - - Args: - num_devices: Number of devices over which the model is being distributed. - out_labels: The list of metric names passed to `compile`. - stateful_metric_names: List of stateful metric names on the model. - outs: The output from all the replicas. - - Returns: - The average value of each metric across the replicas. - """ - # TODO(anjalisridhar): Temporary workaround for aggregating metrics - # across replicas. Replace with the new metrics module eventually. - merged_output = [] - # The first output is the total loss. - merged_output.append(outs[0]) - current_index = 1 - # Each label in `out_labels` corresponds to one set of metrics. The - # number of metric values corresponds to the number of devices. We - # currently take the mean of the values. - for metric_name in out_labels[1:]: - if metric_name in stateful_metric_names: - # For stateful metrics, we get one aggregated result value. - merged_output.append(outs[current_index]) - current_index += 1 - else: - m = np.mean(outs[current_index:current_index + num_devices]) - merged_output.append(m) - current_index += num_devices - - return merged_output - - def _get_input_from_iterator(iterator, model): """Get elements from the iterator and verify the input shape and type.""" next_element = iterator.get_next() diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index 955493a8123f01..9131df5cd0a355 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import math_ops @@ -48,7 +49,12 @@ def _eager_loss_fn(outputs, targets, loss_fn, output_name): return loss -def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None): +def _eager_metrics_fn(model, + outputs, + targets, + sample_weights=None, + masks=None, + return_stateful_result=True): """Calculates the metrics for each output of the given model. Arguments: @@ -57,6 +63,8 @@ def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None): targets: The predictions or targets of the given model. sample_weights: Optional list of sample weights for each output. masks: Optional list of masks for each output. + return_stateful_result: Boolean, indicates whether the stateful + (aggregated)/stateless metric result should be returned. Returns: Returns the metric results for each output of the model. @@ -65,11 +73,20 @@ def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None): targets = generic_utils.to_list(targets) # TODO(psv): Consider supporting skip target indices in eager mode? metric_results = model._handle_metrics( - outputs, targets=targets, sample_weights=sample_weights, masks=masks) + outputs, + targets=targets, + sample_weights=sample_weights, + masks=masks, + return_stateful_result=return_stateful_result) return [backend.mean(t) for t in metric_results] -def _model_loss(model, inputs, targets, sample_weights=None, training=False): +def _model_loss(model, + inputs, + targets, + output_loss_metrics=None, + sample_weights=None, + training=False): """Calculates the loss for a given model. Arguments: @@ -77,6 +94,8 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): inputs: Either a dictionary of inputs to the model or a list of input arrays. targets: List of target arrays. + output_loss_metrics: List of metrics that are used to aggregated output + loss values. sample_weights: Optional list of sample weight arrays. training: Whether the model should be run in inference or training mode. @@ -106,6 +125,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): targets = generic_utils.to_list(targets) loss_metrics = [] + aggregated_loss_metrics = [] with backend.name_scope('loss'): for i, loss_fn in enumerate(model.loss_functions): if sample_weights: @@ -125,6 +145,16 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): if len(model.outputs) > 1: loss_metrics.append(backend.mean(output_loss)) + if output_loss_metrics is not None: + # Keep track of the stateful loss result. + aggregated_loss_metrics.append( + training_utils.call_metric_function( + output_loss_metrics[i], + targets[i], + outs[i], + weights=weights, + mask=mask)) + loss_weight = model.loss_weights_list[i] if total_loss is None: total_loss = loss_weight * output_loss @@ -138,7 +168,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): total_loss += math_ops.add_n(custom_losses) model._clear_losses() - return outs, total_loss, loss_metrics, masks + return outs, total_loss, loss_metrics, aggregated_loss_metrics, masks def _maybe_build_graph_functions(model): @@ -156,22 +186,30 @@ def _maybe_build_graph_functions(model): model._built_graph_functions = True -def _maybe_graph_function_model_loss( - model, - inputs, - targets, - sample_weights=None, - training=False): +def _maybe_graph_function_model_loss(model, + inputs, + targets, + output_loss_metrics=None, + sample_weights=None, + training=False): """Compute model loss, using defun if the model supports it.""" if model._can_use_graph_functions: _maybe_build_graph_functions(model) return model._eager_model_loss_graph_function( - model, inputs, targets, - sample_weights=sample_weights, training=training) + model, + inputs, + targets, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, + training=training) else: - return _model_loss(model, inputs, targets, - sample_weights=sample_weights, - training=training) + return _model_loss( + model, + inputs, + targets, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, + training=training) def _maybe_graph_function_model_call(model, *args, **kwargs): @@ -196,7 +234,8 @@ def iterator_fit_loop(model, callbacks=None, validation_steps=None, do_validation=False, - batch_size=None): + batch_size=None, + output_loss_metrics=None): """Fit function for eager execution when input is given as dataset iterator. Updates the given epoch logs. @@ -222,6 +261,8 @@ def iterator_fit_loop(model, do_validation: Boolean value indicating whether we should do validation. batch_size: int, val_inputs and val_targets will be evaled batch by batch with size batch_size if they are array. + output_loss_metrics: List of metrics that are used to aggregated output + loss values. Raises: ValueError: In case of mismatch between given number of inputs and @@ -276,7 +317,7 @@ def iterator_fit_loop(model, for cbk in callbacks: if (isinstance(cbk, cbks.BaseLogger) or isinstance(cbk, cbks.ProgbarLogger)): - cbk.stateful_metrics = model.stateful_metric_names + cbk.stateful_metrics = model.metrics_names[1:] # Exclude `loss` if step_index == 0 and not callbacks.params['metrics']: callback_metrics = copy.copy(model.metrics_names) @@ -293,21 +334,26 @@ def iterator_fit_loop(model, }) # Train model. - outs, loss, loss_metrics, masks = ( - _maybe_graph_function_process_single_batch( - model, x, y, sample_weights=sample_weights, training=True)) + outs, loss, _, aggregated_loss_metrics, masks = \ + _maybe_graph_function_process_single_batch( + model, + x, + y, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, + training=True) outs = generic_utils.to_list(outs) # Calculate metrics. for l, o in zip(model.metrics_names, outs): batch_logs[l] = o - # Required for eager execution metrics_results = _eager_metrics_fn( model, outs, y, sample_weights=sample_weights, masks=masks) batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss)) - for k, v in zip(model.metrics_names, - [backend.mean(loss)] + loss_metrics + metrics_results): + for k, v in zip( + model.metrics_names, + [backend.mean(loss)] + aggregated_loss_metrics + metrics_results): batch_logs[k] = tensor_util.constant_value(v) callbacks.on_batch_end(step_index, batch_logs) if callbacks.model.stop_training: @@ -357,6 +403,15 @@ def iterator_test_loop(model, inputs, steps, verbose=0): raise ValueError('Please provide either inputs and targets' 'or inputs, targets, and sample_weights') outs = [] + + # Create metric wrapper for the losses. + output_loss_metrics = [] + for i in range(len(model.outputs)): + loss_fn = model.loss_functions[i] + mean_wrapped_loss = metrics_module.MeanMetricWrapper( + loss_fn, name=loss_fn.__name__) + output_loss_metrics.append(mean_wrapped_loss) + num_samples = 0 if verbose == 1: progbar = generic_utils.Progbar(target=steps) @@ -397,21 +452,24 @@ def iterator_test_loop(model, inputs, steps, verbose=0): if hasattr(model, 'metrics'): for m in model.stateful_metric_functions: m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(model.metrics_names) - if str(name) in model.stateful_metric_names - ] - else: - stateful_metric_indices = [] + for m in output_loss_metrics: + m.reset_states() # Calculate model output, loss values. - loss_outs, loss, loss_metrics, masks = _maybe_graph_function_model_loss( - model, x, y, sample_weights=sample_weights, training=False) + loss_outs, loss, _, aggregated_loss_metrics, masks = \ + _maybe_graph_function_model_loss( + model, + x, + y, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, + training=False) metrics_results = _eager_metrics_fn( model, loss_outs, y, sample_weights=sample_weights, masks=masks) batch_outs = [] - for _, v in zip(model.metrics_names, - [backend.mean(loss)] + loss_metrics + metrics_results): + for _, v in zip( + model.metrics_names, + [backend.mean(loss)] + aggregated_loss_metrics + metrics_results): batch_outs.append(tensor_util.constant_value(v)) # Get current step size. @@ -428,20 +486,15 @@ def iterator_test_loop(model, inputs, steps, verbose=0): if step_index == 0: for _ in enumerate(batch_outs): outs.append(0.) - for i, batch_out in enumerate(batch_outs): - if i in stateful_metric_indices: - outs[i] = batch_out - else: - outs[i] += batch_out * step_size + outs[0] += batch_outs[0] * step_size # index 0 = 'loss' + outs[1:] = batch_outs[1:] # Calculate sample size. num_samples += step_size if verbose == 1: progbar.update(step_index + 1) - for i in range(len(outs)): - if i not in stateful_metric_indices: - outs[i] /= num_samples + outs[0] /= num_samples # index 0 = 'loss' if len(outs) == 1: return outs[0] return outs @@ -527,6 +580,7 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): def _process_single_batch(model, inputs, targets, + output_loss_metrics=None, sample_weights=None, training=False): """Calculate the loss and gradient for one input batch. @@ -537,6 +591,8 @@ def _process_single_batch(model, model: Model whose loss has to be calculated. inputs: List of input arrays. targets: List of target arrays. + output_loss_metrics: List of metrics that are used to aggregated output + loss values. sample_weights: Optional list of sample weight arrays. training: The boolean represents if the weights of the model are updated. 'fit' methods will set this to True while 'evaluate' methods will @@ -551,12 +607,14 @@ def _process_single_batch(model, """ with backend.learning_phase_scope(1 if training else 0): with GradientTape() as tape: - outs, loss, loss_metrics, masks = _model_loss( - model, - inputs, - targets, - sample_weights=sample_weights, - training=training) + outs, loss, loss_metrics, aggregated_loss_metrics, masks\ + = _model_loss( + model, + inputs, + targets, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, + training=training) if loss is None: raise ValueError('The model cannot be run ' 'because it has no loss to optimize.') @@ -569,25 +627,33 @@ def _process_single_batch(model, grads = tape.gradient(loss, model._collected_trainable_weights) model.optimizer.apply_gradients(zip(grads, model._collected_trainable_weights)) - return outs, loss, loss_metrics, masks + return outs, loss, loss_metrics, aggregated_loss_metrics, masks -def _maybe_graph_function_process_single_batch( - model, - inputs, - targets, - sample_weights=None, - training=False): +def _maybe_graph_function_process_single_batch(model, + inputs, + targets, + output_loss_metrics=None, + sample_weights=None, + training=False): """Process a single batch, using defun if the model supports it.""" if model._can_use_graph_functions: _maybe_build_graph_functions(model) return model._eager_process_single_batch_graph_function( - model, inputs, targets, sample_weights=sample_weights, + model, + inputs, + targets, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, training=training) else: - return _process_single_batch(model, inputs, targets, - sample_weights=sample_weights, - training=training) + return _process_single_batch( + model, + inputs, + targets, + output_loss_metrics=output_loss_metrics, + sample_weights=sample_weights, + training=training) def train_on_batch(model, inputs, targets, sample_weights=None): @@ -618,12 +684,18 @@ def train_on_batch(model, inputs, targets, sample_weights=None): if val is not None else None for val in sample_weights ] - outs, loss, loss_metrics, masks = _maybe_graph_function_process_single_batch( - model, inputs, targets, sample_weights=sample_weights, training=True) + outs, loss, loss_metrics, _, masks = \ + _maybe_graph_function_process_single_batch( + model, inputs, targets, sample_weights=sample_weights, training=True) if not isinstance(outs, list): outs = [outs] metrics_results = _eager_metrics_fn( - model, outs, targets, sample_weights=sample_weights, masks=masks) + model, + outs, + targets, + sample_weights=sample_weights, + masks=masks, + return_stateful_result=False) loss = generic_utils.to_list(loss) return [ @@ -659,12 +731,17 @@ def test_on_batch(model, inputs, targets, sample_weights=None): ops.convert_to_tensor(val, dtype=backend.floatx()) if val is not None else None for val in sample_weights ] - outs, loss, loss_metrics, masks = _maybe_graph_function_model_loss( + outs, loss, loss_metrics, _, masks = _maybe_graph_function_model_loss( model, inputs, targets, sample_weights=sample_weights, training=False) if not isinstance(outs, list): outs = [outs] metrics_results = _eager_metrics_fn( - model, outs, targets, sample_weights=sample_weights, masks=masks) + model, + outs, + targets, + sample_weights=sample_weights, + masks=masks, + return_stateful_result=False) loss = generic_utils.to_list(loss) return [ @@ -746,12 +823,24 @@ def fit_loop(model, validation_steps=validation_steps, verbose=verbose) + # Create metric wrapper for the losses. + output_loss_metrics = [] + for i in range(len(model.outputs)): + loss_fn = model.loss_functions[i] + mean_wrapped_loss = metrics_module.MeanMetricWrapper( + loss_fn, name=loss_fn.__name__) + output_loss_metrics.append(mean_wrapped_loss) + callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): if model._is_compiled: # Model may not be compiled the first time. # Reset stateful metrics for m in model.stateful_metric_functions: m.reset_states() + + for m in output_loss_metrics: + m.reset_states() + callbacks.on_epoch_begin(epoch) epoch_logs = {} iterator_fit_loop( @@ -768,7 +857,8 @@ def fit_loop(model, callbacks=callbacks, validation_steps=validation_steps, do_validation=do_validation, - batch_size=batch_size) + batch_size=batch_size, + output_loss_metrics=output_loss_metrics) callbacks.on_epoch_end(epoch, epoch_logs) if callbacks.model.stop_training: break diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 234cbab9ebff4c..76aaf1643b07e5 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -196,8 +196,7 @@ def test_loss_correctness(self): np.random.seed(123) y = np.random.randint(0, 1, size=(100, 1)) history = model.fit(x, y, epochs=1, batch_size=10) - self.assertEqual( - np.around(history.history['loss'][-1], decimals=4), 0.6173) + self.assertAlmostEqual(history.history['loss'][-1], 0.6173, 4) @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness_with_iterator(self): @@ -220,7 +219,7 @@ def test_loss_correctness_with_iterator(self): dataset = dataset.batch(10) iterator = dataset.make_one_shot_iterator() history = model.fit(iterator, epochs=1, steps_per_epoch=10) - self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173) + self.assertAlmostEqual(history.history['loss'][-1], 0.6173, 4) def test_loss_in_call(self): diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index 21f44423ec03c7..b5e3a039767d74 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -243,11 +243,6 @@ def evaluate_generator(model, if hasattr(model, 'metrics'): for m in model.stateful_metric_functions: m.reset_states() - stateful_metric_indices = [ - i for i, name in enumerate(model.metrics_names) - if str(name) in model.stateful_metric_names] - else: - stateful_metric_indices = [] steps_done = 0 all_outs = [] @@ -329,13 +324,12 @@ def evaluate_generator(model, if not isinstance(outs, list): return np.average(np.asarray(all_outs), weights=batch_sizes) else: - averages = [] - for i in range(len(outs)): - if i not in stateful_metric_indices: - averages.append( - np.average([out[i] for out in all_outs], weights=batch_sizes)) - else: - averages.append(np.float64(all_outs[-1][i])) + averages = [float(all_outs[-1][0])] # index 0 = 'loss' + averages.extend([ + np.average([out[i] + for out in all_outs], weights=batch_sizes) + for i in range(1, len(outs)) + ]) return averages @@ -348,7 +342,7 @@ def predict_generator(model, verbose=0): """See docstring for `Model.predict_generator`.""" if not context.executing_eagerly(): - model._make_test_function() + model._make_predict_function() steps_done = 0 all_outs = [] diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index faafc60d428f36..df5669b5cc12df 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.callbacks import Callback from tensorflow.python.keras.engine.training_utils import weighted_masked_objective from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops @@ -535,6 +536,64 @@ def test_compile_warning_for_loss_missing_output(self): 'expecting any data to be passed to "dense_1".') self.assertRegexpMatches(str(mock_log.call_args), msg) + def test_logs_passed_to_callbacks(self): + with self.cached_session(): + input_dim = 5 + num_classes = 1 + + class TestCallback(Callback): + + def __init__(self): + super(TestCallback, self).__init__() + self.epoch_end_logs = None + self.batch_end_logs = None + self.epoch_end_call_count = 0 + self.batch_end_call_count = 0 + + def on_epoch_end(self, epoch, logs=None): + self.epoch_end_logs = logs + self.epoch_end_call_count += 1 + + def on_batch_end(self, batch, logs=None): + self.batch_end_logs = logs + self.batch_end_call_count += 1 + + model = testing_utils.get_small_sequential_mlp( + num_hidden=10, num_classes=num_classes, input_dim=input_dim) + model.compile( + loss='binary_crossentropy', + metrics=['acc'], + weighted_metrics=['mae'], + optimizer=RMSPropOptimizer(learning_rate=0.01)) + + np.random.seed(1337) + (x_train, y_train), (_, _) = testing_utils.get_test_data( + train_samples=10, + test_samples=10, + input_shape=(input_dim,), + num_classes=num_classes) + + test_callback = TestCallback() + model.fit( + x_train, + y_train, + batch_size=2, + epochs=2, + verbose=0, + callbacks=[test_callback], + validation_data=(x_train, y_train)) + self.assertEqual(test_callback.batch_end_call_count, 10) + self.assertEqual(test_callback.epoch_end_call_count, 2) + self.assertSetEqual( + set(test_callback.batch_end_logs.keys()), + set(['batch', 'size', 'acc', 'loss', 'weighted_mean_absolute_error'])) + self.assertSetEqual( + set(test_callback.epoch_end_logs.keys()), + set([ + 'acc', 'loss', 'weighted_mean_absolute_error', 'val_acc', + 'val_loss', 'val_weighted_mean_absolute_error' + ])) + class LossWeightingTest(test.TestCase): @@ -2059,12 +2118,7 @@ def test_metrics_names(self): 'dense_binary_accuracy', 'dropout_mean_squared_error', 'dropout_binary_accuracy' ] - reference_stateful_metric_names = [ - 'dense_binary_accuracy', 'dropout_binary_accuracy' - ] self.assertEqual(reference_metric_names, model.metrics_names) - self.assertEqual(reference_stateful_metric_names, - model.stateful_metric_names) # Verify that model metric names are not altered during training. input_a_np = np.random.random((10, 3)) @@ -2077,8 +2131,6 @@ def test_metrics_names(self): epochs=1, batch_size=5) self.assertEqual(reference_metric_names, model.metrics_names) - self.assertEqual(reference_stateful_metric_names, - model.stateful_metric_names) @tf_test_util.run_in_graph_and_eager_modes def test_metrics_correctness(self): @@ -2152,8 +2204,7 @@ def test_metrics_correctness_with_weighted_metrics(self): RMSPropOptimizer(learning_rate=0.001), loss='mse', sample_weight_mode='temporal', - weighted_metrics=['accuracy', - metrics_module.BinaryAccuracy()]) + weighted_metrics=['accuracy', 'mse']) y = np.array([[[1.], [1.]], [[1.], [1.]]]) outs = model.evaluate(x, y) @@ -2165,7 +2216,15 @@ def test_metrics_correctness_with_weighted_metrics(self): w = np.array([[3., 4.], [1., 2.]]) outs = model.evaluate(x, y, sample_weight=w) - self.assertArrayNear(outs, [0.3, 0.7, 0.7], .001) + self.assertArrayNear(outs, [0.3, 0.7, 0.3], .001) + + # Verify that metric value is same with arbitrary weights and batch size. + x = np.random.random((50, 2, 1)) + y = np.random.random((50, 2, 1)) + w = np.random.random((50, 2)) + mse1 = model.evaluate(x, y, sample_weight=w, batch_size=5)[2] + mse2 = model.evaluate(x, y, sample_weight=w, batch_size=10)[2] + self.assertEqual(mse1, mse2) @tf_test_util.run_in_graph_and_eager_modes def test_metric_state_reset_between_fit_and_evaluate(self): @@ -2216,19 +2275,18 @@ def test_metrics_masking(self): model.compile( RMSPropOptimizer(learning_rate=0.001), loss='mse', - weighted_metrics=['accuracy', - metrics_module.BinaryAccuracy()]) + weighted_metrics=['accuracy']) - # verify that masking is applied for stateless and stateful metrics. + # verify that masking is applied. x = np.array([[[1], [1]], [[1], [1]], [[0], [0]]]) y = np.array([[[1], [1]], [[0], [1]], [[1], [1]]]) scores = model.train_on_batch(x, y) - self.assertArrayNear(scores, [0.25, 0.75, 0.75], 0.1) + self.assertArrayNear(scores, [0.25, 0.75], 0.1) # verify that masking is combined with sample weights. w = np.array([3, 2, 4]) scores = model.train_on_batch(x, y, sample_weight=w) - self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1) + self.assertArrayNear(scores, [0.2, 0.8], 0.1) def test_losses_in_defun(self): with context.eager_mode(): diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 7034874ee8de74..e563b7a23dfa4b 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -34,6 +34,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import weights_broadcast_ops @@ -510,8 +511,15 @@ def collect_per_output_metric_info(metrics, For instance, if the model has 2 outputs, and for the first output we want to compute "binary_accuracy" and "binary_crossentropy", and just "binary_accuracy" for the second output, - the list would look like: `[[('acc', binary_accuracy()), - ('ce', binary_crossentropy())], [('acc', binary_accuracy())]]` + the list would look like: `[ + { + 'acc': (binary_accuracy(), mean_obj_1), + 'ce': (binary_crossentropy(), mean_obj_2) + }, + { + 'acc': (binary_accuracy(), mean_obj_3) + } + ]` Raises: TypeError: if an incorrect type is passed for the `metrics` argument. @@ -541,7 +549,19 @@ def collect_per_output_metric_info(metrics, metric_name = get_metric_name(metric, weighted) metric_fn = get_metric_function( metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]) - metrics_dict[metric_name] = metric_fn + + # If the metric function is not stateful, we create a stateful version and + # return both the stateless and the stateful version together. For batch + # APIs like `train_on_batch` we will use the stateless version and for + # other APIs like `fit` we will use the stateful version. + is_stateful = isinstance(metric_fn, + base_layer.Layer) and metric_fn.stateful + stateful_fn = metric_fn + if not is_stateful: + stateful_fn = metrics_module.MeanMetricWrapper( + metric_fn, name=metric_fn.__name__) + + metrics_dict[metric_name] = (metric_fn, stateful_fn) per_output_metrics.append(metrics_dict) return per_output_metrics @@ -608,19 +628,10 @@ def weighted(y_true, y_pred, weights, mask=None): if weights is None: weights = mask else: - # Update shape of weights if possible before adding mask. # Update dimensions of weights to match with mask if possible. mask, _, weights = metrics_module.squeeze_or_expand_dimensions( mask, None, weights) - try: - # Broadcast weights if possible. - weights = weights_broadcast_ops.broadcast_weights(weights, mask) - weights *= mask - except ValueError: - score_array *= mask - score_array /= K.mean(mask) - # TODO(psv): Handle case when mask and weight shapes are not - # compatible. + weights *= mask # Apply sample weighting. if weights is not None: @@ -813,6 +824,23 @@ def get_metric_function(metric, output_shape=None, loss_fn=None): return metrics_module.get(metric) +def call_metric_function(metric_fn, y_true, y_pred, weights=None, mask=None): + """Invokes metric function and returns the metric result tensor.""" + if mask is None: + return metric_fn(y_true, y_pred, sample_weight=weights) + + mask = math_ops.cast(mask, y_pred.dtype) + if weights is None: + # Use mask as sample weight. + return metric_fn(y_true, y_pred, sample_weight=mask) + + # Update dimensions of weights to match with mask. + mask, _, weights = metrics_module.squeeze_or_expand_dimensions( + mask, None, weights) + weights *= mask + return metric_fn(y_true, y_pred, sample_weight=weights) + + def validate_iterator_input(x, y, sample_weight, validation_split=None): """Validates user input arguments when a dataset iterator is passed. diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index bc343325395417..33e526352fae16 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function -from abc import ABCMeta -from abc import abstractmethod - +import abc import functools import sys import types @@ -269,6 +267,7 @@ def _maybe_adjust_weights(): return y_pred, y_true, sample_weight +@six.add_metaclass(abc.ABCMeta) class Metric(Layer): """Encapsulates metric logic and state. @@ -351,7 +350,6 @@ def result(self): return array_ops.identity(self.true_positives) ``` """ - __metaclass__ = ABCMeta def __init__(self, name=None, dtype=None): super(Metric, self).__init__(name=name, dtype=dtype) @@ -403,7 +401,7 @@ def reset_states(self): for v in self.variables: K.set_value(v, 0) - @abstractmethod + @abc.abstractmethod def update_state(self, *args, **kwargs): """Accumulates statistics for the metric. @@ -424,7 +422,7 @@ def update_state(self, *args, **kwargs): """ NotImplementedError('Must be implemented in subclasses.') - @abstractmethod + @abc.abstractmethod def result(self): """Computes and returns the metric value tensor. diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 11054b5b08686c..0c9c066a852007 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -341,11 +341,11 @@ def _in_place_subclassed_model_reset(model): 'weighted_metrics', 'metrics_names', 'metrics_tensors', - 'metrics_updates', - 'stateful_metric_names', 'total_loss', 'sample_weights', '_feed_sample_weights', + '_fit_function', + '_eval_function', 'train_function', 'test_function', 'predict_function', diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index f742e8aa26521c..b4b84fad0cecd8 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -54,6 +54,7 @@ cuda_py_test( srcs = ["optimizer_v2_test.py"], additional_deps = [ ":optimizer_v2", + "//tensorflow/python/eager:def_function", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py index 6c67cd3a61aabe..b05811c419fa8e 100644 --- a/tensorflow/python/keras/optimizer_v2/adam.py +++ b/tensorflow/python/keras/optimizer_v2/adam.py @@ -94,10 +94,10 @@ def __init__(self, """ super(Adam, self).__init__(name) - self._lr = learning_rate - self._beta_1 = beta_1 - self._beta_2 = beta_2 - self._epsilon = epsilon + self._set_hyper('learning_rate', learning_rate) + self._set_hyper('beta_1', beta_1) + self._set_hyper('beta_2', beta_2) + self._set_hyper('epsilon', epsilon) def _create_slots(self, var_list): # Create slots for the first and second moments. @@ -114,11 +114,21 @@ def _resource_apply_dense(self, grad, var): var.handle, m.handle, v.handle, - math_ops.cast(self._beta_1, grad.dtype.base_dtype), - math_ops.cast(self._beta_2, grad.dtype.base_dtype), - math_ops.cast(self._lr, grad.dtype.base_dtype), - math_ops.cast(self._beta_1, grad.dtype.base_dtype), - math_ops.cast(self._beta_2, grad.dtype.base_dtype), - math_ops.cast(self._epsilon, grad.dtype.base_dtype), + math_ops.cast(self._get_hyper('beta_1'), grad.dtype.base_dtype), + math_ops.cast(self._get_hyper('beta_2'), grad.dtype.base_dtype), + math_ops.cast(self._get_hyper('learning_rate'), grad.dtype.base_dtype), + math_ops.cast(self._get_hyper('beta_1'), grad.dtype.base_dtype), + math_ops.cast(self._get_hyper('beta_2'), grad.dtype.base_dtype), + math_ops.cast(self._get_hyper('epsilon'), grad.dtype.base_dtype), grad, use_locking=self._use_locking) + + def get_config(self): + config = super(Adam, self).get_config() + config.update({ + 'learning_rate': self._serialize_hyperparameter('learning_rate'), + 'beta_1': self._serialize_hyperparameter('beta_1'), + 'beta_2': self._serialize_hyperparameter('beta_2'), + 'epsilon': self._serialize_hyperparameter('epsilon'), + }) + return config diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py index e26c82279f580f..3ee1982af95c64 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py @@ -72,9 +72,10 @@ def _resource_apply_dense(self, grad, var): grad, use_locking=self._use_locking) - def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): + def _resource_apply_sparse_duplicate_indices(self, grad, var, indices): return resource_variable_ops.resource_scatter_add( - handle.handle, indices, -grad * self._get_hyper("learning_rate")) + var.handle, indices, -grad * math_ops.cast( + self._get_hyper("learning_rate"), var.dtype.base_dtype)) def _apply_sparse_duplicate_indices(self, grad, var): delta = ops.IndexedSlices( diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py index a1f534d55f6367..3fb15c51d04b8b 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py @@ -66,11 +66,7 @@ def testBasicResourceVariable(self): grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) sgd = gradient_descent.SGD(3.0) sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) - # TODO(apassos) calling initialize_resources on all resources here - # doesn't work because the sessions and graph are reused across unit - # tests and this would mean trying to reinitialize variables. Figure out - # a long-term solution for this. - resources.initialize_resources([var0, var1, sgd.iteration]).run() + variables.global_variables_initializer().run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) @@ -118,11 +114,7 @@ def testMinimizeResourceVariable(self): loss = pred * pred sgd = gradient_descent.SGD(1.0) sgd_op = sgd.minimize(loss, [var0, var1]) - # TODO(apassos) calling initialize_resources on all resources here - # doesn't work because the sessions and graph are reused across unit - # tests and this would mean trying to reinitialize variables. Figure out - # a long-term solution for this. - resources.initialize_resources([var0, var1, sgd.iteration]).run() + variables.global_variables_initializer().run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) self.assertAllCloseAccordingToType([3.0], var1.eval()) @@ -258,12 +250,6 @@ def step(): # be an EagerTensor once again, not a graph Tensor. self.assertEqual(float(step()), -1.0) - def testConfig(self): - opt = gradient_descent.SGD(learning_rate=1.0) - config = opt.get_config() - opt2 = gradient_descent.SGD.from_config(config) - self.assertEqual(opt._hyper["learning_rate"], opt2._hyper["learning_rate"]) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index c820847e53d7d3..26e6dc294c0a90 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -22,10 +22,13 @@ import abc +import six + from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.keras import backend from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import control_flow_ops @@ -38,6 +41,7 @@ from tensorflow.python.util import nest +@six.add_metaclass(abc.ABCMeta) class OptimizerV2(optimizer_v1.Optimizer): """Updated base class for optimizers. @@ -134,7 +138,9 @@ def __init__(self, name): self._use_locking = True super(OptimizerV2, self).__init__(self._use_locking, name) self._hyper = {} + # dict: {variable name : {slot name : variable}} self._slots = {} + self._weights = [] self._prepared = False def minimize(self, @@ -314,7 +320,8 @@ def update_grad_to_var(grad, var): with ops.name_scope(name, self._name) as name: self._prepare() for grad, var in grads_and_vars: - scope_name = "" if in_eager_execution() else "_" + var.op.name + scope_name = ("" if ops.executing_eagerly_outside_functions() else + "_" + var.op.name) with ops.name_scope("update" + scope_name), ops.colocate_with(var): update_ops.append(update_grad_to_var(grad, var)) with ops.colocate_with(self._iterations): @@ -322,33 +329,61 @@ def update_grad_to_var(grad, var): return control_flow_ops.group(*update_ops) def _set_hyper(self, name, value): - self._hyper[name] = value + """set hyper `name` to value. value can be callable, tensor, numeric.""" + if name not in self._hyper: + self._hyper[name] = value + else: + prev_value = self._hyper[name] + if callable(prev_value) or isinstance(prev_value, + (ops.Tensor, int, float)): + self._hyper[name] = value + else: + backend.set_value(self._hyper[name], value) def _get_hyper(self, name): - # TODO(tanzheny): if hyper variable exists then return it. value = self._hyper[name] return self._call_if_callable(value) + def __setattr__(self, name, value): + """Override setattr to support dynamic hyperparameter setting.""" + if hasattr(self, "_hyper") and name in self._hyper: + self._set_hyper(name, value) + else: + super(OptimizerV2, self).__setattr__(name, value) + def add_slot(self, var, slot_name): - slot_key = _get_slot_key_from_var(var, slot_name) - if slot_key not in self._slots: - self._slots[slot_key] = self.add_weight( - name=slot_key, shape=var.shape, dtype=var.dtype) + var_key = _var_key(var) + slot_dict = self._slots.setdefault(var_key, {}) + if slot_name not in slot_dict: + slot_key = _get_slot_key_from_var(var, slot_name) + weight = self.add_weight(name=slot_key, shape=var.shape, dtype=var.dtype) + slot_dict[slot_name] = weight + self._weights.append(weight) def get_slot(self, var, slot_name): - slot_key = _get_slot_key_from_var(var, slot_name) - return self._slots[slot_key] + var_key = _var_key(var) + slot_dict = self._slots[var_key] + return slot_dict[slot_name] def _prepare(self): if self._prepared: return - # This is where all hyper variables will be created. with ops.device("cpu:0"): self._iterations = self.add_weight( - self._name + "/iter", + "iter", shape=[], trainable=False, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + for name, value in self._hyper.items(): + if isinstance(value, ops.Tensor) or callable(value): + pass + else: + self._hyper[name] = self.add_weight( + name, + shape=[], + trainable=False, + initializer=value, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) self._prepared = True @property @@ -392,7 +427,42 @@ def from_config(cls, config, custom_objects=None): def _serialize_hyperparameter(self, hyperparameter_name): """Serialize a hyperparameter that can be a float, callable, or Tensor.""" - return self._hyper[hyperparameter_name] + value = self._get_hyper(hyperparameter_name) + if callable(value): + return value() + if isinstance(value, (ops.Tensor, variables.Variable)): + return backend.get_value(value) + return value + + @property + def weights(self): + """Returns variables of this Optimizer based on the order created.""" + return self._weights + + def get_weights(self): + params = self.weights + return backend.batch_get_value(params) + + # TODO(tanzheny): Maybe share this logic with base_layer. + def set_weights(self, weights): + params = self.weights + if len(params) != len(weights): + raise ValueError( + "You called `set_weights(weights)` on optimizer " + self._name + + " with a weight list of length " + str(len(weights)) + + ", but the optimizer was expecting " + str(len(params)) + + " weights. Provided weights: " + str(weights)[:50] + "...") + if not params: + return + weight_value_tuples = [] + param_values = backend.batch_get_value(params) + for pv, p, w in zip(param_values, params, weights): + if pv.shape != w.shape: + raise ValueError("Optimizer weight shape " + str(pv.shape) + + " not compatible with " + "provided weight shape " + str(w.shape)) + weight_value_tuples.append((p, w)) + backend.batch_set_value(weight_value_tuples) def add_weight(self, name, @@ -405,7 +475,8 @@ def add_weight(self, if dtype is None: dtype = dtypes.float32 - initializer = initializers.get(initializer) + if isinstance(initializer, six.string_types) or callable(initializer): + initializer = initializers.get(initializer) if synchronization == variables.VariableSynchronization.ON_READ: if trainable: @@ -425,7 +496,7 @@ def add_weight(self, shape=shape, getter=base_layer.make_variable, overwrite=True, - initializer=initializers.get(initializer), + initializer=initializer, dtype=dtype, trainable=trainable, use_resource=True, @@ -470,39 +541,31 @@ def merge_grad_fn(strategy, grads_and_vars): merge_grad_fn, grads_and_vars) -def in_eager_execution(): - with ops.init_scope(): - return context.executing_eagerly() +def _var_key(var): + """Key for representing a primary variable, for looking up slots. - -def _get_slot_key_from_var(var, slot_name): - """Get the slot key for the variable. - - Scope the slot name in the namespace of the primary variable. - Set "primary.op.name + '/' + slot_name" as default name. - - In graph mode the name is derived from the op. - In eager mode the name is derived from the var. - If distribution strategy exists, then the name is derived from the primary - variable instead of replica variable, i.e., /dense/kernel instead of - /dense/kernel/replica_1. If the slot name is 'm', then the slot variables - being created are /dense/kernel/m and /dense/kernel/m/replica_1, instead of - /dense/kernel/replica_1/m/replica_1. + In graph mode the name is derived from the var shared name. + In eager mode the name is derived from the var unique id. + If distribution strategy exists, get the primary variable first. Args: var: the variable. - slot_name: the name of the slot. Returns: - the name of the variable. + the unique name of the variable. """ # pylint: disable=protected-access if distribution_strategy_context.has_distribution_strategy() and hasattr( var, "_primary_var"): var = var._primary_var - if context.executing_eagerly(): - name = var._shared_name - else: - name = var.op.name + if hasattr(var, "op"): + return var._shared_name + return var._unique_id + + +def _get_slot_key_from_var(var, slot_name): + """Get the slot key for the variable: var_name/slot_name.""" + + name = _var_key(var) return name + "/" + slot_name diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index fe12ab204f1e51..e5d1a104ca408d 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -19,10 +19,13 @@ from __future__ import print_function from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -52,12 +55,48 @@ def testBasic(self): self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer opt_op = sgd.minimize(loss, var_list=[var0, var1]) - self.evaluate(sgd.iteration.initializer) + self.evaluate(variables.global_variables_initializer()) self.evaluate(opt_op) # Validate updated params self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) + @test_util.run_in_graph_and_eager_modes + def testAdaptiveLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + + def loss(): + return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + + sgd = gradient_descent.SGD(1.0) + + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Run 1 step of sgd through optimizer + opt_op = sgd.minimize(loss, [var0, var1]) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt_op) + # Validate updated params + # var0 = [1., 2.] - 1.0 * [5, 5] + self.assertAllClose([-4., -3.], self.evaluate(var0)) + # var1 = [3., 4.] - 1.0 * [3, 3] + self.assertAllClose([0., 1.], self.evaluate(var1)) + + sgd.learning_rate = 0.5 + if context.executing_eagerly(): + sgd.minimize(loss, [var0, var1]) + else: + self.evaluate(opt_op) + # Validate updated params + # var0 = [-4., -3.] - 0.5 * [5, 5] + self.assertAllClose([-6.5, -5.5], self.evaluate(var0)) + # var1 = [0., 1.] - 0.5 * [3, 3] + self.assertAllClose([-1.5, -0.5], self.evaluate(var1)) + @test_util.run_in_graph_and_eager_modes def testAggregationMethod(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: @@ -79,7 +118,7 @@ def testAggregationMethod(self): var_list=[var0, var1], aggregation_method=gradients_impl.AggregationMethod .EXPERIMENTAL_ACCUMULATE_N) - self.evaluate(sgd.iteration.initializer) + self.evaluate(variables.global_variables_initializer()) self.evaluate(opt_op) # Validate updated params self.assertAllClose([-14., -13.], self.evaluate(var0)) @@ -103,7 +142,7 @@ def testPrecomputedGradient(self): self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss) - self.evaluate(sgd.iteration.initializer) + self.evaluate(variables.global_variables_initializer()) self.evaluate(opt_op) # Validate updated params self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)], @@ -184,7 +223,8 @@ def testGradientsAsVariables(self): # Run 1 step of sgd through optimizer converted_grads_and_vars = list(zip(converted_grads, [var0, var1])) opt_op = sgd.apply_gradients(converted_grads_and_vars) - self.evaluate(sgd.iteration.initializer) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(convert_ops) self.evaluate(opt_op) # Validate updated params @@ -229,7 +269,7 @@ def testConstraint(self): self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer opt_op = sgd.minimize(loss, var_list=[var0, var1]) - self.evaluate(sgd.iteration.initializer) + self.evaluate(variables.global_variables_initializer()) self.evaluate(opt_op) # Validate updated params self.assertAllClose([-0.1, -0.1], self.evaluate(var0)) @@ -242,6 +282,111 @@ def testIterationWithoutMinimize(self): self.evaluate(sgd.iteration.initializer) self.assertEqual(0, self.evaluate(sgd.iteration)) + @test_util.run_in_graph_and_eager_modes + def testSerializationWithinDefun(self): + with self.cached_session(): + sgd = gradient_descent.SGD(3.0) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], + dtype=dtypes.float32) + loss = lambda: 3 * var0 + sgd.minimize(loss, [var0]) + + def serialize(): + config = sgd.get_config() + gradient_descent.SGD.from_config(config) + + compiled_serialize = function.defun(serialize) + with self.assertRaisesRegexp(RuntimeError, 'inside Tensorflow graph'): + compiled_serialize() + + @test_util.run_in_graph_and_eager_modes + def testConfig(self): + with self.cached_session(): + opt = gradient_descent.SGD(learning_rate=1.0) + config = opt.get_config() + opt2 = gradient_descent.SGD.from_config(config) + # assert both are equal float values. + self.assertEqual( + opt._get_hyper('learning_rate'), opt2._get_hyper('learning_rate')) + var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32) + loss = lambda: 3 * var0 + # learning rate variable created when calling minimize. + opt.minimize(loss, [var0]) + self.evaluate(variables.global_variables_initializer()) + config = opt.get_config() + opt3 = gradient_descent.SGD.from_config(config) + self.assertEqual( + self.evaluate(opt._get_hyper('learning_rate')), + opt3._get_hyper('learning_rate')) + + @test_util.run_in_graph_and_eager_modes + def testWeights(self): + with self.cached_session(): + opt1 = adam.Adam(learning_rate=1.0) + var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], + dtype=dtypes.float32) + loss1 = lambda: 3 * var1 + opt_op_1 = opt1.minimize(loss1, [var1]) + self.evaluate(variables.global_variables_initializer()) + config = opt1.get_config() + opt2 = adam.Adam.from_config(config) + var2 = resource_variable_ops.ResourceVariable([1.0, 2.0], + dtype=dtypes.float32) + loss2 = lambda: 3 * var2 + opt_op_2 = opt2.minimize(loss2, [var2]) + weights = opt1.get_weights() + + # Assert set_weights and both variables get updated to same value. + self.evaluate(variables.global_variables_initializer()) + opt2.set_weights(weights) + self.evaluate([opt_op_1, opt_op_2]) + self.assertAllClose(self.evaluate(var1), self.evaluate(var2)) + self.assertEqual(1, self.evaluate(opt1.iteration)) + self.assertEqual(1, self.evaluate(opt2.iteration)) + + var3 = resource_variable_ops.ResourceVariable([1.0, 2.0, 3.0], + dtype=dtypes.float32) + var4 = resource_variable_ops.ResourceVariable([4.0, 5.0, 6.0], + dtype=dtypes.float32) + loss3 = lambda: 3 * var3 + 5 * var4 + opt_op_3 = opt1.minimize(loss3, [var3, var4]) + + # Assert set_weights with ValueError since weight list does not match. + self.evaluate(variables.global_variables_initializer()) + weights = opt1.get_weights() + with self.assertRaisesRegexp(ValueError, 'but the optimizer was'): + opt2.set_weights(weights) + + # Assert set_weights and variables get updated to same value. + var5 = resource_variable_ops.ResourceVariable([1.0, 2.0, 3.0], + dtype=dtypes.float32) + var6 = resource_variable_ops.ResourceVariable([4.0, 5.0, 6.0], + dtype=dtypes.float32) + loss4 = lambda: 3 * var5 + 5 * var6 + opt_op_4 = opt2.minimize(loss4, [var5, var6]) + self.evaluate(variables.global_variables_initializer()) + opt2.set_weights(weights) + self.evaluate([opt_op_3, opt_op_4]) + self.assertAllClose( + self.evaluate([var3, var4]), self.evaluate([var5, var6])) + + def testOptimizerWithFunction(self): + with context.eager_mode(): + var = resource_variable_ops.ResourceVariable([1.0, 2.0], + dtype=dtypes.float32) + loss = lambda: 3 * var + opt = adam.Adam(learning_rate=1.0) + + @def_function.function + def fn(): + opt.minimize(loss, [var]) + return var + + self.assertAllClose([0., 1.], fn()) + # This is just to test tf.function. The values needs to be updated + # when adam updates beta_1_power. + self.assertAllClose([-1.343838, -0.343838], fn()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 2fae094a1ef93e..d342131a521a90 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -149,7 +149,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, np.testing.assert_allclose(output, actual_output, rtol=1e-3) # test training mode (e.g. useful for dropout tests) - model.compile(RMSPropOptimizer(0.01), 'mse') + model.compile(RMSPropOptimizer(0.01), 'mse', weighted_metrics=['acc']) model.train_on_batch(input_data, actual_output) # test as first layer in Sequential API diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 2e56fa2dc54746..5af82f36911578 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -146,6 +146,8 @@ def deserialize_keras_object(identifier, module_objects=None, custom_objects=None, printable_module_name='object'): + if identifier is None: + return None if isinstance(identifier, dict): # In this case we are dealing with a Keras config dictionary. config = identifier diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py index 87bc19eb37d15d..ead4beee1cbeb7 100644 --- a/tensorflow/python/keras/utils/generic_utils_test.py +++ b/tensorflow/python/keras/utils/generic_utils_test.py @@ -71,5 +71,15 @@ class CustomClass(object): self.assertEqual(cl.__class__, CustomClass) +class SerializeKerasObjectTest(test.TestCase): + + def test_serialize_none(self): + serialized = keras.utils.generic_utils.serialize_keras_object(None) + self.assertEqual(serialized, None) + deserialized = keras.utils.generic_utils.deserialize_keras_object( + serialized) + self.assertEqual(deserialized, None) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index dc3f8a7b516999..e6508fde0f682c 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2562,6 +2562,8 @@ cuda_py_test( ], shard_count = 4, tags = [ + # TODO(b/118887316): Re-enable this test in Kokoro. + "no_oss", "optonly", # times out ], ) @@ -2579,6 +2581,8 @@ cuda_py_test( "//tensorflow/python:nn_grad", "//tensorflow/python:nn_ops", ], + # TODO(b/118842098): Re-enable this test in Kokoro. + tags = ["no_oss"], ) tf_py_test( @@ -2726,6 +2730,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "huge_slice_op_test", + size = "medium", + srcs = ["huge_slice_op_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + ], + tags = [ + "no_oss", # Requires 4GB+ RAM + ], +) + cuda_py_test( name = "sparse_matmul_op_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index fd7f9d2798aa7c..3a5d817e9d6d95 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -928,7 +928,6 @@ def testWhile_2(self): r = isum(s) self.assertAllEqual(45, r.eval()) - @test_util.disable_control_flow_v2("b/115776323 (max_iters)") def testWhileWithMaximumIterations(self): with self.cached_session(): s = constant_op.constant([1, 2, 3, 4, 5]) @@ -1045,7 +1044,7 @@ def create_mi(): r"while loop context '' \(currently defined in 'cond/.+'\)"): _ = gradients_impl.gradients(loop, v) - @test_util.disable_control_flow_v2("b/115776323 (max_iters)") + @test_util.disable_control_flow_v2("b/118457764") def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self): v = constant_op.constant(1.0) @@ -1277,8 +1276,8 @@ def b(i, j): r = control_flow_ops.while_loop( c, b, [i, m], [i.get_shape(), tensor_shape.TensorShape([None, 2])]) - self.assertIsNone(r[1].get_shape()[0].value) - self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2)) + self.assertIsNone(r[1].shape.dims[0].value) + self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2)) with self.assertRaisesRegexp( ValueError, @@ -1991,7 +1990,6 @@ def testNestedWhileCondWhileGrad(self): def testNestedWhileCondWhileGradGpu(self): self._testNestedWhileCondWhileGrad(use_gpu=True) - @test_util.disable_control_flow_v2("b/116823782") def testWhileGrad_Variable(self): with self.cached_session(): a = variables.Variable(3.0) @@ -2004,6 +2002,18 @@ def testWhileGrad_Variable(self): variables.global_variables_initializer().run() self.assertAllClose(216.0, r[0].eval()) + def testWhileGrad_ResourceVariable(self): + with self.cached_session(): + a = resource_variable_ops.ResourceVariable(3.0) + v = constant_op.constant(2.0, name="v") + c = lambda v: math_ops.less(v, 100.0) + b = lambda v: math_ops.multiply(v, a) + r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) + + g = gradients_impl.gradients(r, a) + variables.global_variables_initializer().run() + self.assertAllClose(216.0, g[0].eval()) + def testWhileGradInCond(self): with self.cached_session(): @@ -2709,7 +2719,6 @@ def body(k, w, chg_w): grad, = gradients_impl.gradients(w, c) self.assertIsNotNone(grad) - @test_util.disable_control_flow_v2("b/116270461 (resource)") def testStopGradMultiFlows(self): with self.cached_session(): diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index f35450b6fd67d1..04c1032722caac 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -829,6 +829,49 @@ def Body(n, x): self.assertAllEqual(5050., sess.run([result, c], feed_dict={n: 100.})[0]) + # pylint: disable=cell-var-from-loop + def testWhileCapturedInputs(self): + for use_gpu in (True, False): + with ops.Graph().as_default() as g: + v = variables.Variable(1.0) + + def TestCond(n, *args): + del args + return n < 10 + + @function.Defun(*[dtypes.float32] * 2) + def TestUnary(n, x): + return math_ops.add(n, 1), x + n + v + + @function.Defun(*[dtypes.float32] * 3) + def TestBinary(n, x, x2): + return math_ops.add(n, 1), x + n + v, x2 + v + + with self.session(graph=g, use_gpu=use_gpu) as sess: + result_unary = functional_ops.While( + [1.0, 0.], + function.Defun(*[dtypes.float32] * 2)(TestCond), TestUnary) + result_binary = functional_ops.While( + [1.0, 0., 0.], + function.Defun(*[dtypes.float32] * 3)(TestCond), TestBinary) + sess.run(variables.global_variables_initializer()) + assert len(result_unary) == 2 + self.assertEqual([10.0, 54.0], sess.run(result_unary)) + assert len(result_binary) == 3 + self.assertEqual([10.0, 54.0, 9.0], sess.run(result_binary)) + + def TestCondCapture(n, *args): + del args + return math_ops.to_float(n) + v < 10 + + with self.assertRaises(ValueError): + _ = functional_ops.While( + [1], + function.Defun(dtypes.int32)(TestCondCapture), + function.Defun(dtypes.int32, dtypes.float32)(TestUnary)) + + # pylint: enable=cell-var-from-loop + def _tfSum(self, use_gpu, rewrite_with_while): with ops.Graph().as_default() as g: with self.session(graph=g, use_gpu=use_gpu) as sess: diff --git a/tensorflow/python/kernel_tests/huge_slice_op_test.py b/tensorflow/python/kernel_tests/huge_slice_op_test.py new file mode 100644 index 00000000000000..8646d74c96f179 --- /dev/null +++ b/tensorflow/python/kernel_tests/huge_slice_op_test.py @@ -0,0 +1,43 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for slice op that consume a lot of GPU memory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SliceTest(test.TestCase): + + def testInt64Slicing(self): + with self.cached_session(force_gpu=test.is_gpu_available()): + a_large = array_ops.tile( + constant_op.constant(np.array([False, True] * 4)), [2**29 + 3]) + slice_t = array_ops.slice(a_large, np.asarray([3]).astype(np.int64), [3]) + slice_val = slice_t.eval() + self.assertAllEqual([True, False, True], slice_val) + + slice_t = array_ops.slice( + a_large, constant_op.constant([long(2)**32 + 3], dtype=dtypes.int64), + [3]) + slice_val = slice_t.eval() + self.assertAllEqual([True, False, True], slice_val) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py index 13218787e22575..31fb19e4a69b68 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py @@ -205,7 +205,7 @@ def test_static_dims_broadcast(self): result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs) self.assertAllEqual((2, 3, 7), result.get_shape()) expected = linalg_ops.cholesky_solve(chol_broadcast, rhs) - self.assertAllEqual(expected.eval(), result.eval()) + self.assertAllClose(expected.eval(), result.eval()) def test_dynamic_dims_broadcast_64bit(self): # batch_shape = [2, 2] @@ -228,12 +228,12 @@ def test_dynamic_dims_broadcast_64bit(self): chol_ph: chol, rhs_ph: rhs, }) - self.assertAllEqual(expected, result) + self.assertAllClose(expected, result) class MatmulWithBroadcastTest(test.TestCase): - def test_static_dims_broadcast(self): + def test_static_dims_broadcast_x_has_extra_dims(self): # batch_shape = [2] # for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7 x = rng.rand(2, 1, 3) @@ -244,7 +244,69 @@ def test_static_dims_broadcast(self): result = linear_operator_util.matmul_with_broadcast(x, y) self.assertAllEqual((2, 1, 7), result.get_shape()) expected = math_ops.matmul(x, y_broadcast) - self.assertAllEqual(expected.eval(), result.eval()) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_y_has_extra_dims(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + x = rng.rand(5, 7) + y = rng.rand(2, 3, 7, 5) + x_broadcast = x + np.zeros((2, 3, 5, 7)) + + with self.cached_session(): + result = linear_operator_util.matmul_with_broadcast(x, y) + self.assertAllEqual((2, 3, 5, 5), result.get_shape()) + expected = math_ops.matmul(x_broadcast, y) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_y_has_extra_dims_transpose_a_and_b(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + x = rng.rand(1, 7, 5) + y = rng.rand(2, 3, 1, 7) + x_broadcast = x + np.zeros((2, 3, 1, 1)) + + with self.cached_session(): + result = linear_operator_util.matmul_with_broadcast( + x, y, transpose_a=True, transpose_b=True) + self.assertAllEqual((2, 3, 5, 1), result.get_shape()) + expected = math_ops.matmul( + x_broadcast, y, transpose_a=True, transpose_b=True) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_y_has_extra_dims_transpose_dynamic(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + x = rng.rand(1, 7, 5) + y = rng.rand(2, 3, 1, 7) + x_broadcast = x + np.zeros((2, 3, 1, 1)) + + x_ph = array_ops.placeholder(dtypes.float64, [None, None, None]) + y_ph = array_ops.placeholder(dtypes.float64, [None, None, None, None]) + + with self.cached_session(): + result = linear_operator_util.matmul_with_broadcast( + x_ph, y_ph, transpose_a=True, transpose_b=True) + self.assertAllEqual(4, result.shape.ndims) + expected = math_ops.matmul( + x_broadcast, y, transpose_a=True, transpose_b=True) + self.assertAllClose(expected.eval(), + result.eval(feed_dict={ + x_ph: x, + y_ph: y + })) def test_dynamic_dims_broadcast_64bit(self): # batch_shape = [2] @@ -266,22 +328,87 @@ def test_dynamic_dims_broadcast_64bit(self): x_ph: x, y_ph: y }) - self.assertAllEqual(expected, result) + self.assertAllClose(expected, result) class MatrixSolveWithBroadcastTest(test.TestCase): - def test_static_dims_broadcast(self): + def test_static_dims_broadcast_matrix_has_extra_dims(self): + # batch_shape = [2] + matrix = rng.rand(2, 3, 3) + rhs = rng.rand(3, 7) + rhs_broadcast = rhs + np.zeros((2, 1, 1)) + + with self.cached_session(): + result = linear_operator_util.matrix_solve_with_broadcast( + matrix, rhs) + self.assertAllEqual((2, 3, 7), result.get_shape()) + expected = linalg_ops.matrix_solve(matrix, rhs_broadcast) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_rhs_has_extra_dims(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. # batch_shape = [2] matrix = rng.rand(3, 3) - rhs = rng.rand(2, 3, 7) + rhs = rng.rand(2, 3, 2) matrix_broadcast = matrix + np.zeros((2, 1, 1)) with self.cached_session(): result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs) - self.assertAllEqual((2, 3, 7), result.get_shape()) + self.assertAllEqual((2, 3, 2), result.get_shape()) + expected = linalg_ops.matrix_solve(matrix_broadcast, rhs) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + # batch_shape = [2] + matrix = rng.rand(3, 3) + rhs = rng.rand(2, 3, 2) + matrix_broadcast = matrix + np.zeros((2, 1, 1)) + + matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None]) + rhs_ph = array_ops.placeholder(dtypes.float64, shape=[None, None, None]) + + with self.cached_session(): + result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph, + rhs_ph) + self.assertAllEqual(3, result.shape.ndims) expected = linalg_ops.matrix_solve(matrix_broadcast, rhs) - self.assertAllEqual(expected.eval(), result.eval()) + self.assertAllClose( + expected.eval(), + result.eval(feed_dict={ + matrix_ph: matrix, + rhs_ph: rhs + })) + + def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + # batch_shape = [2] + matrix = rng.rand(3, 3) + rhs = rng.rand(2, 3, 2) + matrix_broadcast = matrix + np.zeros((2, 1, 1)) + + with self.cached_session(): + result = linear_operator_util.matrix_solve_with_broadcast( + matrix, rhs, adjoint=True) + self.assertAllEqual((2, 3, 2), result.get_shape()) + expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True) + self.assertAllClose(expected.eval(), result.eval()) def test_dynamic_dims_broadcast_64bit(self): # batch_shape = [2, 2] @@ -304,12 +431,12 @@ def test_dynamic_dims_broadcast_64bit(self): matrix_ph: matrix, rhs_ph: rhs, }) - self.assertAllEqual(expected, result) + self.assertAllClose(expected, result) class MatrixTriangularSolveWithBroadcastTest(test.TestCase): - def test_static_dims_broadcast(self): + def test_static_dims_broadcast_matrix_has_extra_dims(self): # batch_shape = [2] matrix = rng.rand(2, 3, 3) rhs = rng.rand(3, 7) @@ -320,7 +447,46 @@ def test_static_dims_broadcast(self): matrix, rhs) self.assertAllEqual((2, 3, 7), result.get_shape()) expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast) - self.assertAllEqual(expected.eval(), result.eval()) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_rhs_has_extra_dims(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + # batch_shape = [2] + matrix = rng.rand(3, 3) + rhs = rng.rand(2, 3, 2) + matrix_broadcast = matrix + np.zeros((2, 1, 1)) + + with self.cached_session(): + result = linear_operator_util.matrix_triangular_solve_with_broadcast( + matrix, rhs) + self.assertAllEqual((2, 3, 2), result.get_shape()) + expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs) + self.assertAllClose(expected.eval(), result.eval()) + + def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self): + # Since the second arg has extra dims, and the domain dim of the first arg + # is larger than the number of linear equations, code will "flip" the extra + # dims of the first arg to the far right, making extra linear equations + # (then call the matrix function, then flip back). + # We have verified that this optimization indeed happens. How? We stepped + # through with a debugger. + # batch_shape = [2] + matrix = rng.rand(3, 3) + rhs = rng.rand(2, 3, 2) + matrix_broadcast = matrix + np.zeros((2, 1, 1)) + + with self.cached_session(): + result = linear_operator_util.matrix_triangular_solve_with_broadcast( + matrix, rhs, adjoint=True) + self.assertAllEqual((2, 3, 2), result.get_shape()) + expected = linalg_ops.matrix_triangular_solve( + matrix_broadcast, rhs, adjoint=True) + self.assertAllClose(expected.eval(), result.eval()) def test_dynamic_dims_broadcast_64bit(self): # batch_shape = [2] @@ -342,7 +508,7 @@ def test_dynamic_dims_broadcast_64bit(self): matrix_ph: matrix, rhs_ph: rhs, }) - self.assertAllEqual(expected, result) + self.assertAllClose(expected, result) class DomainDimensionStubOperator(object): diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index d2128f0cb8d565..d57012dc860d85 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import test -from tensorflow.python.training import server_lib def scalar_shape(): @@ -395,35 +394,51 @@ def body(i, m, t1): @test_util.run_in_graph_and_eager_modes def testSerialize(self): - # pylint: disable=g-import-not-at-top - try: - import portpicker - except ImportError: - return - with context.graph_mode(): - worker_port = portpicker.pick_unused_port() - ps_port = portpicker.pick_unused_port() - cluster_dict = { - "worker": ["localhost:%s" % worker_port], - "ps": ["localhost:%s" % ps_port] - } - cs = server_lib.ClusterSpec(cluster_dict) - - worker = server_lib.Server( - cs, job_name="worker", protocol="grpc", task_index=0, start=True) - unused_ps = server_lib.Server( - cs, job_name="ps", protocol="grpc", task_index=0, start=True) - with ops.Graph().as_default(), session.Session(target=worker.target): - with ops.device("/job:worker"): - t = constant_op.constant([[1.0], [2.0]]) - l = list_ops.tensor_list_from_tensor(t, element_shape=[1]) - with ops.device("/job:ps"): - l_ps = array_ops.identity(l) - l_ps, e = list_ops.tensor_list_pop_back( - l_ps, element_dtype=dtypes.float32) - with ops.device("/job:worker"): - worker_e = array_ops.identity(e) - self.assertAllEqual(self.evaluate(worker_e), [2.0]) + worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0] + with ops.Graph().as_default(), session.Session(target=worker.target): + with ops.device("/job:worker"): + t = constant_op.constant([[1.0], [2.0]]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[1]) + with ops.device("/job:ps"): + l_ps = array_ops.identity(l) + l_ps, e = list_ops.tensor_list_pop_back( + l_ps, element_dtype=dtypes.float32) + with ops.device("/job:worker"): + worker_e = array_ops.identity(e) + self.assertAllEqual(self.evaluate(worker_e), [2.0]) + + @test_util.run_in_graph_and_eager_modes + def testSerializeListWithInvalidTensors(self): + worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0] + with ops.Graph().as_default(), session.Session(target=worker.target): + with ops.device("/job:worker"): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, + element_shape=scalar_shape(), + num_elements=2) + l = list_ops.tensor_list_set_item(l, 0, 1.) + with ops.device("/job:ps"): + l_ps = array_ops.identity(l) + l_ps = list_ops.tensor_list_set_item(l_ps, 1, 2.) + t = list_ops.tensor_list_stack(l_ps, element_dtype=dtypes.float32) + with ops.device("/job:worker"): + worker_t = array_ops.identity(t) + self.assertAllEqual(self.evaluate(worker_t), [1.0, 2.0]) + + @test_util.run_in_graph_and_eager_modes + def testSerializeListWithUnknownRank(self): + worker = test_util.create_local_cluster(num_workers=1, num_ps=1)[0][0] + with ops.Graph().as_default(), session.Session(target=worker.target): + with ops.device("/job:worker"): + t = constant_op.constant([[1.0], [2.0]]) + l = list_ops.tensor_list_from_tensor(t, element_shape=-1) + with ops.device("/job:ps"): + l_ps = array_ops.identity(l) + element_shape = list_ops.tensor_list_element_shape( + l_ps, shape_type=dtypes.int32) + with ops.device("/job:worker"): + element_shape = array_ops.identity(element_shape) + self.assertEqual(self.evaluate(element_shape), -1) @test_util.run_in_graph_and_eager_modes def testPushPopGradients(self): diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 6791a03e2e1acc..bd93942efbd016 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -44,7 +44,7 @@ def testHashTable(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -68,7 +68,7 @@ def testHashTableFindHighRank(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -87,7 +87,7 @@ def testHashTableInitWithPythonArrays(self): table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer( keys, values, value_dtype=dtypes.int64), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -104,7 +104,7 @@ def testHashTableInitWithNumPyArrays(self): values = np.array([0, 1, 2], dtype=np.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) @@ -149,7 +149,7 @@ def testHashTableWithTensorDefault(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "tank"]) output = table.lookup(input_string) @@ -164,7 +164,7 @@ def testHashTableWithSparseTensorInput(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() sp_indices = [[0, 0], [0, 1], [1, 0]] sp_shape = [2, 2] @@ -187,7 +187,7 @@ def testSignatureMismatch(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() # Ref types do not produce a lookup signature mismatch. input_string_ref = variables.Variable("brain") @@ -230,10 +230,10 @@ def testInitializeTwice(self): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() with self.assertRaisesOpError("Table already initialized"): - table.init.run() + table.initializer.run() def testInitializationWithInvalidDimensions(self): with self.cached_session(): @@ -265,13 +265,13 @@ def testMultipleSessions(self): # Init the table in the first session. with session1: - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) # Init the table in the second session and verify that we do not get a # "Table already initialized" error. with session2: - table.init.run() + table.initializer.run() self.assertAllEqual(3, table.size().eval()) def testHashTableInt32String(self): @@ -281,7 +281,7 @@ def testHashTableInt32String(self): values = constant_op.constant(["brain", "salad", "surgery"]) table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table.init.run() + table.initializer.run() input_tensor = constant_op.constant([0, 1, -1]) output = table.lookup(input_tensor) @@ -305,7 +305,8 @@ def test_string_index_table_from_file(self): vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -320,7 +321,8 @@ def test_string_index_table_from_multicolumn_file(self): value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -336,7 +338,8 @@ def test_string_index_table_from_multicolumn_file_custom_delimiter(self): delimiter=" ") ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -348,7 +351,8 @@ def test_string_index_table_from_file_tensor_filename(self): vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) self.assertEqual(1, @@ -362,7 +366,8 @@ def test_string_index_table_from_file_placeholder_filename(self): vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() feed_dict = {vocabulary_placeholder.name: vocabulary_file} lookup_ops.tables_initializer().run(feed_dict=feed_dict) @@ -381,7 +386,8 @@ def test_int32_index_table_from_file(self): ids = table.lookup( constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -396,7 +402,8 @@ def test_int64_index_table_from_file(self): ids = table.lookup( constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -408,7 +415,8 @@ def test_index_table_from_file_with_default_value(self): vocabulary_file=vocabulary_file, default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) @@ -420,7 +428,8 @@ def test_index_table_from_file_with_oov_buckets(self): ids = table.lookup( constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual( ( @@ -466,7 +475,8 @@ def test_index_table_from_file_with_vocab_size_too_small(self): vocabulary_file=vocabulary_file, vocab_size=2) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, -1, -1), ids.eval()) self.assertEqual(2, table.size().eval()) @@ -477,7 +487,7 @@ def test_index_table_from_file_with_vocab_size_too_large(self): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Invalid vocab_size", table.init.run) + "Invalid vocab_size", table.initializer.run) def test_index_table_from_file_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab8.txt") @@ -493,7 +503,8 @@ def test_index_table_from_file_with_vocab_size(self): vocabulary_file=vocabulary_file, vocab_size=3) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), ids.eval()) self.assertEqual(3, table.size().eval()) @@ -522,14 +533,14 @@ def test_index_table_from_file_table_ref_with_oov_buckets(self): with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) - self.assertIsNotNone(table.table_ref) + self.assertIsNotNone(table.resource_handle) def test_index_table_from_file_table_ref_without_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab10.txt") with self.cached_session(): table = lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=0) - self.assertIsNotNone(table.table_ref) + self.assertIsNotNone(table.resource_handle) class KeyValueTensorInitializerTest(test.TestCase): @@ -539,14 +550,32 @@ def test_string(self): init = lookup_ops.KeyValueTensorInitializer( ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) table = lookup_ops.HashTable(init, default_value=-1) - table.init.run() + table.initializer.run() + + def test_multiple_tables(self): + with ops.Graph().as_default(), self.cached_session(): + with ops.name_scope("table_scope"): + init1 = lookup_ops.KeyValueTensorInitializer( + ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, + dtypes.int64) + table1 = lookup_ops.HashTable(init1, default_value=-1) + self.assertEquals("hash_table", table1.name) + self.assertEquals("table_scope/hash_table", + table1.resource_handle.op.name) + init2 = lookup_ops.KeyValueTensorInitializer( + ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, + dtypes.int64) + table2 = lookup_ops.HashTable(init2, default_value=-1) + self.assertEquals("hash_table_1", table2.name) + self.assertEquals("table_scope/hash_table_1", + table2.resource_handle.op.name) def test_int64(self): with ops.Graph().as_default(), self.cached_session(): init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64) table = lookup_ops.HashTable(init, default_value=-1) - table.init.run() + table.initializer.run() def test_int32(self): with ops.Graph().as_default(), self.cached_session(): @@ -555,7 +584,7 @@ def test_int32(self): table = lookup_ops.HashTable(init, default_value=-1) with self.assertRaisesRegexp( errors_impl.OpError, "No OpKernel was registered"): - table.init.run() + table.initializer.run() class IndexTableFromTensor(test.TestCase): @@ -584,7 +613,8 @@ def test_int32_index_table_from_tensor_with_tensor_init(self): ids = table.lookup( constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.FailedPreconditionError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -595,7 +625,8 @@ def test_int64_index_table_from_tensor_with_tensor_init(self): ids = table.lookup( constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.FailedPreconditionError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) @@ -607,7 +638,8 @@ def test_index_table_from_tensor_with_default_value(self): default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.FailedPreconditionError): + ids.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) @@ -623,7 +655,8 @@ def test_index_table_from_tensor_empty_vocabulary_list(self): table = lookup_ops.index_table_from_tensor( vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"])) - self.assertRaises(errors_impl.OpError, ids.eval) + with self.assertRaises(errors_impl.OpError): + ids.eval() with self.assertRaisesRegexp( errors_impl.OpError, "keys and values cannot be empty"): lookup_ops.tables_initializer().run() @@ -664,7 +697,8 @@ def test_index_to_string_table(self): vocabulary_file=vocabulary_file) features = table.lookup( constant_op.constant([0, 1, 2, 3], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -678,7 +712,8 @@ def test_index_to_string_table_from_multicolumn_file(self): key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, value_column_index=0) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -693,7 +728,8 @@ def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self): value_column_index=0, delimiter=" ") features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -705,7 +741,8 @@ def test_index_to_string_table_with_default_value(self): table = lookup_ops.index_to_string_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -719,7 +756,8 @@ def test_index_to_string_table_with_vocab_size_too_small(self): vocab_size=2, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", default_value, default_value), features.eval()) @@ -731,7 +769,8 @@ def test_index_to_string_table_with_vocab_size_too_large(self): vocabulary_file=vocabulary_file, vocab_size=4) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() init = lookup_ops.tables_initializer() self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", init.run) @@ -743,7 +782,8 @@ def test_index_to_string_table_with_vocab_size(self): vocabulary_file=vocabulary_file, vocab_size=3) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) @@ -758,7 +798,8 @@ def test_index_to_string_table_from_tensor(self): indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) features = table.lookup(indices) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), @@ -782,7 +823,8 @@ def test_index_to_string_with_default_value(self): vocabulary_list=vocabulary_list, default_value=default_value) indices = constant_op.constant([1, 2, 4], dtypes.int64) features = table.lookup(indices) - self.assertRaises(errors_impl.OpError, features.eval) + with self.assertRaises(errors_impl.OpError): + features.eval() lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), @@ -805,7 +847,7 @@ def testInitializeStringTable(self): lookup_ops.TextFileInitializer( vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), default_value) - self.evaluate(table.init) + self.evaluate(table.initializer) output = table.lookup(constant_op.constant(["brain", "salad", "tank"])) @@ -823,7 +865,7 @@ def testInitializeInt64Table(self): vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), default_value) - table.init.run() + table.initializer.run() output = table.lookup( constant_op.constant((42, 1, 11), dtype=dtypes.int64)) @@ -842,7 +884,7 @@ def testInitializeIndexTable(self): lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index), default_value) - table.init.run() + table.initializer.run() input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) output = table.lookup(input_values) @@ -864,7 +906,7 @@ def testMultiColumn(self): lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index), default_value) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "surgery"]) output = table.lookup(input_string) @@ -886,7 +928,7 @@ def testInvalidDataTypeInMultiColumn(self): key_index, dtypes.int64, value_index), default_value) with self.assertRaisesOpError("is not a valid"): - table.init.run() + table.initializer.run() def testInvalidDataType(self): vocabulary_file = self._createVocabFile("one_column_3.txt") @@ -914,7 +956,7 @@ def testInvalidIndex(self): default_value) with self.assertRaisesOpError("Invalid number of columns"): - table.init.run() + table.initializer.run() def testInitializeSameTableWithMultipleNodes(self): vocabulary_file = self._createVocabFile("one_column_5.txt") @@ -982,7 +1024,7 @@ def testInitializeWithVocabSize(self): vocab_size=vocab_size), default_value) # Initialize from file. - table1.init.run() + table1.initializer.run() self.assertEquals(vocab_size, table1.size().eval()) vocabulary_file2 = self._createVocabFile("one_column7.txt") @@ -996,7 +1038,7 @@ def testInitializeWithVocabSize(self): lookup_ops.TextFileIndex.LINE_NUMBER, vocab_size=vocab_size), default_value) with self.assertRaisesOpError("Invalid vocab_size"): - table2.init.run() + table2.initializer.run() vocab_size = 1 vocabulary_file3 = self._createVocabFile("one_column3.txt") @@ -1010,7 +1052,7 @@ def testInitializeWithVocabSize(self): vocab_size=vocab_size), default_value) # Smaller vocab size reads only vocab_size records. - table3.init.run() + table3.initializer.run() self.assertEquals(vocab_size, table3.size().eval()) def testFeedVocabularyName(self): @@ -1027,11 +1069,11 @@ def testFeedVocabularyName(self): # Initialize with non existing file (old_file.txt) should fail. # TODO(yleon): Update message, which might change per FileSystem. with self.assertRaisesOpError("old_file.txt"): - table.init.run() + table.initializer.run() # Initialize the model feeding the vocabulary file. filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) - table.init.run(feed_dict={filenames[0]: vocabulary_file}) + table.initializer.run(feed_dict={filenames[0]: vocabulary_file}) input_string = constant_op.constant(["brain", "salad", "tank"]) output = table.lookup(input_string) @@ -1072,7 +1114,7 @@ def testIdToStringTable(self): lookup_ops.TextFileStringTableInitializer( vocab_file, vocab_size=vocab_size), default_value) - table.init.run() + table.initializer.run() input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) @@ -1088,7 +1130,7 @@ def testStringToIdTable(self): table = lookup_ops.HashTable( lookup_ops.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) @@ -1106,7 +1148,7 @@ def testInt64ToIdTable(self): lookup_ops.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), default_value) - table.init.run() + table.initializer.run() out = table.lookup( constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)) @@ -1134,7 +1176,7 @@ def testStringIdTableWithHashBuckets(self): vocab_file, vocab_size=vocab_size), default_value), oov_buckets) - table.init.run() + table.initializer.run() input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) @@ -1156,7 +1198,7 @@ def testInt32IdTableWithHashBuckets(self): oov_buckets, key_dtype=dtypes.int32) - table.init.run() + table.initializer.run() values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32) @@ -1176,7 +1218,7 @@ def testInt64IdTableWithHashBuckets(self): vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), default_value), oov_buckets) - table.init.run() + table.initializer.run() values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64) @@ -1191,7 +1233,7 @@ def testStringIdTableWithOnlyHashBucket(self): # Set a table that only uses hash buckets, for each input value returns # an id calculated by fingerprint("input") mod oov_buckets. table = lookup_ops.IdTableWithHashBuckets(None, oov_buckets) - table.init.run() + table.initializer.run() values = constant_op.constant(("brain", "salad", "surgery")) @@ -1213,7 +1255,7 @@ def testInt32IdTableWithOnlyHashBucket(self): # an id calculated by fingerprint("input") mod oov_buckets. table = lookup_ops.IdTableWithHashBuckets( None, oov_buckets, key_dtype=dtypes.int32) - table.init.run() + table.initializer.run() input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32) @@ -1293,7 +1335,7 @@ def testIdTableWithHashBucketsInitializationAcrossSessions(self): default_value, shared_name=shared_name), oov_buckets) - table1.init.run() + table1.initializer.run() input_string_1 = constant_op.constant( ["brain", "salad", "surgery", "UNK"]) @@ -1309,7 +1351,7 @@ def testIdTableWithHashBucketsInitializationAcrossSessions(self): oov_buckets = 1 # Underlying lookup table already initialized in previous session. - # No need to call table2.init.run() + # No need to call table2.initializer.run() table2 = lookup_ops.IdTableWithHashBuckets( lookup_ops.HashTable( lookup_ops.TextFileIdTableInitializer( @@ -1373,7 +1415,7 @@ def testSparseTensor(self): lookup_ops.HashTable( lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), -1), 1) - table.init.run() + table.initializer.run() sp_ids = table.lookup(sp_features) @@ -1401,7 +1443,7 @@ def testInt32SparseTensor(self): (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1), 1, key_dtype=dtypes.int32) - table.init.run() + table.initializer.run() sp_ids = table.lookup(sp_features) @@ -1429,7 +1471,7 @@ def testInt64SparseTensor(self): (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1), 1, key_dtype=dtypes.int64) - table.init.run() + table.initializer.run() sp_ids = table.lookup(sp_features) @@ -1487,7 +1529,8 @@ def testIdTableWithHashBucketsWithInvalidHashers(self): def testIdTableWithHashBucketsNoInnerTable(self): with self.cached_session(): table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1) - self.assertIsNone(table.table_ref) + self.assertIsNone(table.resource_handle) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index 4760236ca0e1a0..1c2822180ac986 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -35,6 +35,19 @@ # os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1" to enable it. +class MatVecTest(test_lib.TestCase): + """Simple test for matvec, which is sugar on top of matmul.""" + + def testTwoByTwoCase(self): + a = np.array([[1, 2], [3, 4]]) + b = np.array([5, 6]) + with self.cached_session(): + c = math_ops.matvec(a, b) + self.assertAllEqual((2,), c.shape) + c_ = c.eval() + self.assertAllEqual([5 + 2 * 6, 3 * 5 + 4 * 6], c_) + + def _AddTest(test, op_name, testcase_name, fn): test_name = "_".join(["test", op_name, testcase_name]) if hasattr(test, test_name): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index a9fd93e9f8760f..c8227dc117f316 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -915,6 +915,17 @@ def testAssignIncompatibleShape(self): with self.assertRaisesRegexp(Exception, r"hapes must be equal"): self.assertAllEqual(self.evaluate(v.assign_add(1)), [1, 2, 3, 4]) + @test_util.run_in_graph_and_eager_modes + def testCopyToGraphUninitialized(self): + v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) + copy_to_graph = ops.Graph() + with copy_to_graph.as_default(): # Intentionally testing v1 behavior + copied = resource_variable_ops.copy_to_graph_uninitialized(v) + self.assertEqual(v.name, copied.name) + with self.session(copy_to_graph) as session: + with self.assertRaises(errors.InvalidArgumentError): + session.run(copied.initializer) + class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index 0e8c276ba9c899..41f040ab739451 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -50,9 +50,9 @@ def testInt32(self): slice_val = slice_t.eval() self.assertAllEqual(slice_val, inp[2, k:k]) - def testInt64Slicing(self): - with self.cached_session(use_gpu=True): - a = constant_op.constant([0, 1, 2], dtype=dtypes.int64) + def testSlicingWithInt64Index(self): + with self.cached_session(force_gpu=test.is_gpu_available()): + a = constant_op.constant([0, 1, 2], dtype=dtypes.int32) # Slice using int64 Tensor. i = constant_op.constant(1, dtype=dtypes.int64) @@ -72,6 +72,46 @@ def testInt64Slicing(self): slice_val = slice_t.eval() self.assertAllEqual([1], slice_val) + a_int32 = constant_op.constant([0, 1, 2], dtype=dtypes.int32) + slice_t = array_ops.slice(a_int32, + np.asarray([1]).astype(np.int64), + np.asarray([2]).astype(np.int64)) + slice_val = slice_t.eval() + self.assertAllEqual([1, 2], slice_val) + + a_float32 = constant_op.constant([0, 1, 2], dtype=dtypes.float32) + slice_t = array_ops.slice(a_float32, + np.asarray([1]).astype(np.int64), + np.asarray([2]).astype(np.int64)) + slice_val = slice_t.eval() + self.assertAllEqual([1, 2], slice_val) + + def testSlicingInt64Tensor(self): + with self.cached_session(force_gpu=test.is_gpu_available()): + a = constant_op.constant([0, 1, 2], dtype=dtypes.int64) + + # Slice using int32 Tensor. + i = constant_op.constant(1, dtype=dtypes.int32) + slice_t = a[i] + slice_val = slice_t.eval() + self.assertAllEqual(1, slice_val) + slice_t = a[i:i + 1] + slice_val = slice_t.eval() + self.assertAllEqual([1], slice_val) + + # Slice using int32 integer. + i = np.asarray(1).astype(np.int32) + slice_t = a[i] + slice_val = slice_t.eval() + self.assertAllEqual(1, slice_val) + slice_t = a[i:i + 1] + slice_val = slice_t.eval() + self.assertAllEqual([1], slice_val) + + slice_t = array_ops.slice(a, [1], [2]) + slice_val = slice_t.eval() + self.assertAllEqual([1, 2], slice_val) + def testSelectAll(self): for _ in range(10): with self.cached_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index af18e4b81444f0..dc1bcb78b8066c 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -226,7 +226,7 @@ def Body(x, tl): ("PartiallyDefinedShape", [None, 2]), ("FullyDefinedShape", [1, 2]), ) - def testTensorListOutputElementShape(self, shape): + def testAccumulatorElementShape(self, shape): def MatchShape(actual_tensor_shape): # Compare the shapes, treating None dimensions as equal. We do not @@ -267,7 +267,7 @@ def GetAccumulatorForInputAtIndex(while_op, idx): # values of grad_y. # grad_while_op.inputs: # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args] - grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4) + grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 3) _, val = list_ops.tensor_list_pop_back(grad_output, element_dtype=dtypes.float32) MatchShape(val.shape) diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 71ece0d392a5d1..ec6615ea86657b 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -190,10 +190,10 @@ def batch_normalization(inputs, Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they - need to be added as a dependency to the `train_op`. Also, be sure to add - any batch_normalization ops before getting the update_ops collection. - Otherwise, update_ops will be empty, and training/inference will not work - properly. For example: + need to be executed alongside the `train_op`. Also, be sure to add any + batch_normalization ops before getting the update_ops collection. Otherwise, + update_ops will be empty, and training/inference will not work properly. For + example: ```python x_norm = tf.layers.batch_normalization(x, training=training) @@ -201,8 +201,8 @@ def batch_normalization(inputs, # ... update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - with tf.control_dependencies(update_ops): - train_op = optimizer.minimize(loss) + train_op = optimizer.minimize(loss) + train_op = tf.group([train_op, update_ops]) ``` Arguments: diff --git a/tensorflow/python/lib/io/python_io.py b/tensorflow/python/lib/io/python_io.py index 404423ce07b3bb..8223d3092fc085 100644 --- a/tensorflow/python/lib/io/python_io.py +++ b/tensorflow/python/lib/io/python_io.py @@ -13,10 +13,7 @@ # limitations under the License. # ============================================================================== -"""Python functions for directly manipulating TFRecord-formatted files. - -See the [Python IO](https://tensorflow.org/api_guides/python/python_io) guide. -""" +"""Python functions for directly manipulating TFRecord-formatted files.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ee328df208e667..6fdc50733a1c19 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -13,10 +13,7 @@ # limitations under the License. # ============================================================================== # Tests for this file live in python/kernel_tests/array_ops_test.py -"""Support for manipulating tensors. - -See the [Array Ops](https://tensorflow.org/api_guides/python/array_ops) guide. -""" +"""Support for manipulating tensors.""" from __future__ import absolute_import from __future__ import division @@ -123,7 +120,7 @@ def expand_dims(input, axis=None, name=None, dim=None): axis: 0-D (scalar). Specifies the dimension index at which to expand the shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`. - name: The name of the output `Tensor`. + name: The name of the output `Tensor` (optional). dim: 0-D (scalar). Equivalent to `axis`, to be deprecated. Returns: @@ -131,9 +128,11 @@ def expand_dims(input, axis=None, name=None, dim=None): dimension of size 1 added. Raises: - ValueError: if both `dim` and `axis` are specified. + ValueError: if either both or neither of `dim` and `axis` are specified. """ axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + raise ValueError("Must specify an axis argument to tf.expand_dims()") return gen_array_ops.expand_dims(input, axis, name) diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py index 720f9f4d41e4cc..87d3918a5f9ba7 100644 --- a/tensorflow/python/ops/boosted_trees_ops.py +++ b/tensorflow/python/ops/boosted_trees_ops.py @@ -40,6 +40,7 @@ # pylint: enable=unused-import from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking class PruningMode(object): @@ -102,35 +103,52 @@ def restore(self, restored_tensors, unused_restored_shapes): tree_ensemble_serialized=restored_tensors[1]) -class TreeEnsemble(object): +class TreeEnsemble(tracking.TrackableResource): """Creates TreeEnsemble resource.""" def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''): + self._stamp_token = stamp_token + self._serialized_proto = serialized_proto + self._is_local = is_local with ops.name_scope(name, 'TreeEnsemble') as name: - self._resource_handle = ( - gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op( - container='', shared_name=name, name=name)) - create_op = gen_boosted_trees_ops.boosted_trees_create_ensemble( - self.resource_handle, - stamp_token, - tree_ensemble_serialized=serialized_proto) - is_initialized_op = ( - gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized( - self._resource_handle)) + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() # Adds the variable to the savable list. if not is_local: - saveable = _TreeEnsembleSavable(self.resource_handle, create_op, - self.resource_handle.name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + self._saveable = _TreeEnsembleSavable( + self.resource_handle, self.initializer, self.resource_handle.name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) resources.register_resource( self.resource_handle, - create_op, + self.initializer, is_initialized_op, is_shared=not is_local) + def create_resource(self): + return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op( + container='', shared_name=self._name, name=self._name) + + def initialize(self): + return gen_boosted_trees_ops.boosted_trees_create_ensemble( + self.resource_handle, + self._stamp_token, + tree_ensemble_serialized=self._serialized_proto) + @property - def resource_handle(self): - return self._resource_handle + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized( + self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + if not self._is_local: + return {'tree_ensemble': self._saveable} def get_stamp_token(self): """Returns the current stamp token of the resource.""" diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 40b111ea0c2bca..5589bbc848597c 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -13,11 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=g-short-docstring-punctuation -"""Asserts and Boolean Checks. - -See the [Asserts and -checks](https://tensorflow.org/api_guides/python/check_ops) guide. -""" +"""Asserts and Boolean Checks.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 5604af665eface..d74ab732d74d4f 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -14,8 +14,7 @@ # ============================================================================== """Control Flow Operations. -See the [Control -Flow](https://tensorflow.org/api_guides/python/control_flow_ops) guide. +See the [autograph](https://www.tensorflow.org/guide/autographs) guide. """ # pylint: disable=g-bad-name from __future__ import absolute_import @@ -1485,6 +1484,7 @@ def ZerosLikeOutsideLoop(op, index): return array_ops.zeros_like(val, optimize=False) +@six.add_metaclass(abc.ABCMeta) class ControlFlowContext(object): """The base class for control flow context. @@ -3239,7 +3239,12 @@ def while_loop(cond, """ if ENABLE_WHILE_V2 and not context.executing_eagerly(): return while_v2.while_loop( - cond, body, loop_vars, shape_invariants=shape_invariants, name=name) + cond, + body, + loop_vars, + shape_invariants=shape_invariants, + maximum_iterations=maximum_iterations, + name=name) with ops.name_scope(name, "while", loop_vars): if not loop_vars: @@ -3796,6 +3801,12 @@ def __init__(self): super(XLAControlFlowContext, self).__init__() self._name = "XLAControlFlowContext" + def to_control_flow_context_def(self, context_def, export_scope=None): + # pylint: disable=useless-super-delegation + # NOTE(slebedev): the method is required by `ControlFlowContext`. + super(XLAControlFlowContext, self).to_control_flow_context_def( + context_def, export_scope) + def IsXLAContext(self): return True diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 22263c702ff301..f4b28f0113bebf 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -469,17 +469,28 @@ def testWhileContextWithMaximumIterations(self): self._testWhileContextHelper(maximum_iterations=10) def testControlContextImportScope(self): + class NoABCControlFlowContext(control_flow_ops.ControlFlowContext): + """A noop wrapper around `ControlFlowContext`. + + `ControlFlowContext` is an ABC and therefore cannot be instantiated. + """ + # pylint: disable=useless-super-delegation + + def to_control_flow_context_def(self, context_def, export_scope=None): + super(NoABCControlFlowContext, self).to_control_flow_context_def( + context_def, export_scope) + with self.cached_session(): constant_op.constant(0, name="a") constant_op.constant(2, name="test_scope/a") b1 = constant_op.constant(1, name="b") b2 = constant_op.constant(3, name="test_scope/b") - c = control_flow_ops.ControlFlowContext() + c = NoABCControlFlowContext() c._values = ["a", "b"] c._external_values = {"a": b1} - c_with_scope = control_flow_ops.ControlFlowContext( + c_with_scope = NoABCControlFlowContext( values_def=c._to_values_def(), import_scope="test_scope") # _values and _external_values should be have scope prepended. diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index f2701bc41bdd5f..fecd7ddbf9ffd8 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -13,11 +13,7 @@ # limitations under the License. # ============================================================================= -"""Functional operations. - -See the [Higher Order -Functions](https://tensorflow.org/api_guides/python/functional_ops) guide. -""" +"""Functional operations.""" from __future__ import absolute_import from __future__ import division @@ -806,6 +802,29 @@ def Gradient(inputs, f, name=None): return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name) +def _LoopBodyCaptureWrapper(func): + """Returns a wrapper for `func` that handles loop-carried captured inputs.""" + + @function.Defun( + *func.declared_input_types, func_name="%s_Wrapper" % func.name) + def Wrapper(*args): + """A wrapper that handles loop-carried captured inputs.""" + result = func(*args) + extra_args = tuple(function.get_extra_args()) + # Nullary functions return an Operation. Normal functions can't do this + # because their return values are converted to Tensors. + if isinstance(result, ops.Operation): + return extra_args + # Unary functions return a single Tensor value. + elif not isinstance(result, tuple): + return (result,) + extra_args + # N-ary functions return a tuple of Tensors. + else: + return result + extra_args + + return Wrapper + + # pylint: disable=invalid-name,protected-access def While(input_, cond, body, name=None, hostmem=None): r"""output = input; While (Cond(output)) { output = Body(output) }. @@ -827,11 +846,41 @@ def While(input_, cond, body, name=None, hostmem=None): hostmem: A list of integer. If i is in the list, input[i] is a host memory tensor. + Raises: + ValueError: if `cond` has implicitly captured inputs or if `cond` and `body` + have different signatures. + Returns: A list of `Tensor` objects. Has the same type as `input`. A list of output tensors whose types are T. """ - ret = gen_functional_ops._while(input_, cond, body, name=name) + if cond.captured_inputs: + raise ValueError("While op 'cond' argument must be a function " + "without implicitly captured inputs.") + + if cond.declared_input_types != body.declared_input_types: + raise ValueError( + "While op 'cond' and 'body' signatures do not match. %r vs %r" % + (cond.declared_input_types, body.declared_input_types)) + + if body.captured_inputs: + cond_dtypes = list( + body.declared_input_types) + [t.dtype for t in body.captured_inputs] + + @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name) + def CondWrapper(*args): + """A wrapper that handles loop-carried captured inputs.""" + return cond(*args[:len(body.declared_input_types)]) + + ret = gen_functional_ops._while( + input_ + body.captured_inputs, + CondWrapper, + _LoopBodyCaptureWrapper(body), + name=name) + # Slice off the loop-carried captured inputs. + ret = ret[:-len(body.captured_inputs)] + else: + ret = gen_functional_ops._while(input_, cond, body, name=name) if hostmem: input_attr = attr_value_pb2.AttrValue() input_attr.list.i.extend(hostmem) @@ -880,11 +929,10 @@ def _ForUsingWhile(start, # must have identical inputs, we have to augment the cond signature to take # the same types as the carried loop variables. body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:] - cond_sig = body_sig + [t.dtype for t in forbody.captured_inputs] cond_name = "%s_Cond" % forbody.name - @function.Defun(*cond_sig, func_name=cond_name) + @function.Defun(*body_sig, func_name=cond_name) def WhileCond(i, n, *args): del args return i < n @@ -902,8 +950,7 @@ def WhileBody(i, n, start, delta, *args): # Unary functions return a single Tensor value. elif isinstance(for_result, ops.Tensor): for_result = (for_result,) - extra_args = tuple(function.get_extra_args()) - return (i + 1, n, start, delta) + tuple(for_result) + extra_args + return (i + 1, n, start, delta) + tuple(for_result) if hostmem is not None: hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem] @@ -911,13 +958,13 @@ def WhileBody(i, n, start, delta, *args): hostmem = [0, 1, 2, 3] results = While( - input_=[0, n, start, delta] + inputs + WhileBody.captured_inputs, + input_=[0, n, start, delta] + inputs, cond=WhileCond, body=WhileBody, name=name, hostmem=hostmem) # Slice off the loop-carried captured inputs. - return list(results[4:len(results) - len(WhileBody.captured_inputs)]) + return list(results[4:len(results)]) def For(start, @@ -951,29 +998,15 @@ def For(start, if rewrite_with_while: return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem) if body.captured_inputs: - wrapper_name = "%s_BodyWrapper" % body.name - - @function.Defun(*body.declared_input_types, func_name=wrapper_name) - def BodyWrapper(*args): - """A wrapper for body that handles loop-carried captured inputs.""" - body_result = body(*args) - extra_args = tuple(function.get_extra_args()) - # Nullary functions return an Operation. Normal functions can't do this - # because their return values are converted to Tensors. - if isinstance(body_result, ops.Operation): - return extra_args - # Unary functions return a single Tensor value. - elif not isinstance(body_result, tuple): - return (body_result,) + extra_args - # N-ary functions return a tuple of Tensors. - else: - return body_result + extra_args - - inputs += BodyWrapper.captured_inputs ret = gen_functional_ops._for( - start, limit, delta, inputs, BodyWrapper, name=name) + start, + limit, + delta, + inputs + body.captured_inputs, + _LoopBodyCaptureWrapper(body), + name=name) # Slice off the loop-carried captured inputs. - ret = ret[:-len(BodyWrapper.captured_inputs)] + ret = ret[:-len(body.captured_inputs)] else: ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name) if hostmem: diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 0cb6f80b2cc294..4f0fb54dcab855 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -264,6 +264,12 @@ def _DefaultGradYs(grad_ys, "Gradient type %s generated for variant " "tensor %s with type %s must be variant" % (dtypes.as_dtype( grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) + elif y.dtype == dtypes.resource: + # We assume y is the handle of a ResourceVariable. The gradient of a + # ResourceVariable should be a numeric value, not another resource. + if grad_y.dtype == dtypes.resource: + raise TypeError("Input gradient %s for resource tensor %s should not " + "be a resource" % (grad_y, y)) else: raise TypeError( "Tensor %s with type %s must be numeric " diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 6af7e0fa227837..704ac11d0134ea 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -22,6 +22,7 @@ import contextlib import numpy as np +import six from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -40,6 +41,7 @@ # TODO(langmore) Use matrix_solve_ls for singular or non-square matrices. @tf_export("linalg.LinearOperator") +@six.add_metaclass(abc.ABCMeta) class LinearOperator(object): """Base class defining a [batch of] linear operator[s]. @@ -140,7 +142,6 @@ class LinearOperator(object): * If `is_X == None` (the default), callers should have no expectation either way. """ - __metaclass__ = abc.ABCMeta def __init__(self, dtype, diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py index 9dd40765c20222..54d04e4a70bc65 100644 --- a/tensorflow/python/ops/linalg/linear_operator_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_util.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -25,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg_impl as linalg def assert_no_entries_with_modulus_zero( @@ -233,9 +236,9 @@ def matmul_with_broadcast(a, """Multiplies matrix `a` by matrix `b`, producing `a @ b`. Works identically to `tf.matmul`, but broadcasts batch dims - of `a` and `b` (by replicating) if they are determined statically to be - different, or if static shapes are not fully defined. Thus, this may result - in an inefficient replication of data. + of `a` and `b` if they are determined statically to be different, or if static + shapes are not fully defined. Attempts are made to avoid unnecessary + replication of data, but this is not always possible. The inputs must be matrices (or tensors of rank > 2, representing batches of matrices). @@ -308,23 +311,51 @@ def matmul_with_broadcast(a, are both set to True. """ with ops.name_scope(name, "MatMulWithBroadcast", [a, b]): - a, b = broadcast_matrix_batch_dims([a, b]) - return math_ops.matmul( + a = ops.convert_to_tensor(a, name="a") + b = ops.convert_to_tensor(b, name="b", dtype=a.dtype) + + # If either a or b has extra dims, we can reshape to get rid of them. + a, b, reshape_inv, still_need_to_transpose = _reshape_for_efficiency( a, b, transpose_a=transpose_a, transpose_b=transpose_b, adjoint_a=adjoint_a, - adjoint_b=adjoint_b, + adjoint_b=adjoint_b) + + # This will broadcast by brute force if we still need to. + a, b = broadcast_matrix_batch_dims([a, b]) + + a_times_b = math_ops.matmul( + a, + b, + transpose_a=transpose_a and still_need_to_transpose, + transpose_b=transpose_b and still_need_to_transpose, + adjoint_a=adjoint_a and still_need_to_transpose, + adjoint_b=adjoint_b and still_need_to_transpose, a_is_sparse=a_is_sparse, b_is_sparse=b_is_sparse) + return reshape_inv(a_times_b) + def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None): """Solve systems of linear equations.""" with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]): + matrix = ops.convert_to_tensor(matrix, name="matrix") + rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype) + + # If either matrix/rhs has extra dims, we can reshape to get rid of them. + matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency( + matrix, rhs, adjoint_a=adjoint) + + # This will broadcast by brute force if we still need to. matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs]) - return linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint) + + solution = linalg_ops.matrix_solve( + matrix, rhs, adjoint=adjoint and still_need_to_transpose) + + return reshape_inv(solution) def matrix_triangular_solve_with_broadcast(matrix, @@ -354,9 +385,119 @@ def matrix_triangular_solve_with_broadcast(matrix, `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`. """ with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]): + matrix = ops.convert_to_tensor(matrix, name="matrix") + rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype) + + # If either matrix/rhs has extra dims, we can reshape to get rid of them. + matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency( + matrix, rhs, adjoint_a=adjoint) + + # lower indicates whether the matrix is lower triangular. If we have + # manually taken adjoint inside _reshape_for_efficiency, it is now upper tri + if not still_need_to_transpose and adjoint: + lower = not lower + + # This will broadcast by brute force if we still need to. matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs]) - return linalg_ops.matrix_triangular_solve( + + solution = linalg_ops.matrix_triangular_solve( matrix, rhs, lower=lower, - adjoint=adjoint) + adjoint=adjoint and still_need_to_transpose) + + return reshape_inv(solution) + + +def _reshape_for_efficiency(a, + b, + transpose_a=False, + transpose_b=False, + adjoint_a=False, + adjoint_b=False): + """Maybe reshape a, b, and return an inverse map. For matmul/solve.""" + def identity(x): + return x + + # At this point, we have not taken transpose/adjoint of a/b. + still_need_to_transpose = True + + if a.shape.ndims is None or b.shape.ndims is None: + return a, b, identity, still_need_to_transpose + + # This could be handled in the future, but seems less common. + if a.shape.ndims >= b.shape.ndims: + return a, b, identity, still_need_to_transpose + + # From now on, we might modify b, but will not modify a. + + # Suppose: + # a.shape = C + [m, n], b.shape = + # b.shape = S + C + [n, r] + b_extra_ndims = b.shape.ndims - a.shape.ndims + + # b_extra_sh = S, b_main_sh = C + [n, r] + b_extra_sh = array_ops.shape(b)[:b_extra_ndims] + b_main_sh = array_ops.shape(b)[b_extra_ndims:] + + # No reason to flip unless the extra dims of b are big enough. Why? + # Assume adjoint/transpose = False. Then... + # By not flipping, we have to replicate a to shape + # b_extra_sh + a.shape, + # which could use extra memory. But in all cases, the final output has shape + # b_extra_sh + a.shape[:-1] + [b.shape[-1]] + # So we only end up creating a larger object if the end dim of b is smaller + # than the end dim of a. This often happens, e.g. if b was a vector that was + # expanded to a matrix (by appending a singleton). + + # Since adjoint/transpose may not be False, we must make adjustments here. + # The dim of b that holds the multiple equations. + a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1] + b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1] + b_extra_sz_ = ( + np.prod(b.shape[:b_extra_ndims].as_list()) + if b.shape[:b_extra_ndims].is_fully_defined() else None) + if (a_domain_sz_ is not None and b_eq_sz_ is not None and + b_extra_sz_ is not None): + if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_: + return a, b, identity, still_need_to_transpose + + # At this point, we're flipping for sure! + # Any transposes/adjoints will happen here explicitly, rather than in calling + # code. Why? To avoid having to write separate complex code for each case. + if adjoint_a: + a = linalg.adjoint(a) + elif transpose_a: + a = linalg.transpose(a) + if adjoint_b: + b = linalg.adjoint(b) + elif transpose_b: + b = linalg.transpose(b) + still_need_to_transpose = False + + # Recompute shapes, since the transpose/adjoint may have changed them. + b_extra_sh = array_ops.shape(b)[:b_extra_ndims] + b_main_sh = array_ops.shape(b)[b_extra_ndims:] + + # Permutation to put the extra dims at the end. + perm = ( + array_ops.concat( + (math_ops.range(b_extra_ndims, b.shape.ndims), + math_ops.range(0, b_extra_ndims)), 0)) + b_extra_on_end = array_ops.transpose(b, perm=perm) + + # Now squash this end into one long dim. + b_squashed_end = array_ops.reshape( + b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0)) + + def reshape_inv(y): + # Expand the extra dims hanging off the end, "b_extra_sh". + # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y + # Could have different batch dims than a and b, because of broadcasting. + y_extra_shape = array_ops.concat( + (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0) + y_extra_on_end = array_ops.reshape(y, y_extra_shape) + return array_ops.transpose( + y_extra_on_end, perm=array_ops.invert_permutation(perm)) + + return a, b_squashed_end, reshape_inv, still_need_to_transpose diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index cffaa983d486c9..e65b53e3ac9474 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -39,6 +39,7 @@ # pylint: disable=wildcard-import from tensorflow.python.ops.gen_lookup_ops import * # pylint: enable=wildcard-import +from tensorflow.python.training.checkpointable import tracking as checkpointable from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -96,20 +97,22 @@ def _check_table_dtypes(table, key_dtype, value_dtype): (table.value_dtype, value_dtype)) -class LookupInterface(object): +class LookupInterface(checkpointable.TrackableResource): """Represent a lookup table that persists across different steps.""" - def __init__(self, key_dtype, value_dtype, name): + def __init__(self, key_dtype, value_dtype): """Construct a lookup table interface. Args: key_dtype: The table key type. value_dtype: The table value type. - name: A name for the operation (optional). """ self._key_dtype = dtypes.as_dtype(key_dtype) self._value_dtype = dtypes.as_dtype(value_dtype) - self._name = name + super(LookupInterface, self).__init__() + + def create_resource(self): + raise NotImplementedError @property def key_dtype(self): @@ -124,12 +127,7 @@ def value_dtype(self): @property def name(self): """The name of the table.""" - return self._name - - @property - def init(self): - """The table initialization op.""" - raise NotImplementedError + return NotImplementedError def size(self, name=None): """Compute the number of elements in this table.""" @@ -146,7 +144,7 @@ class InitializableLookupTableBase(LookupInterface): An initializable lookup tables persist across different steps. """ - def __init__(self, table_ref, default_value, initializer): + def __init__(self, default_value, initializer): """Construct a table object from a table reference. If requires a table initializer object (subclass of `TableInitializerBase`). @@ -154,38 +152,35 @@ def __init__(self, table_ref, default_value, initializer): the table. The caller is responsible to execute the initialization op. Args: - table_ref: The table reference, i.e. the output of the lookup table ops. default_value: The value to use if a key is missing in the table. initializer: The table initializer to use. """ - if context.executing_eagerly(): - name = context.context().scope_name - else: - name = table_ref.op.name.split("/")[-1] - super(InitializableLookupTableBase, - self).__init__(initializer.key_dtype, initializer.value_dtype, - name) - self._table_ref = table_ref + super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, + initializer.value_dtype) self._default_value = ops.convert_to_tensor( default_value, dtype=self._value_dtype) self._default_value.get_shape().merge_with(tensor_shape.scalar()) - self._init = initializer.initialize(self) + self._initializer = initializer + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + + def initialize(self): + return self._initializer.initialize(self) @property - def table_ref(self): - """Get the underlying table reference.""" - return self._table_ref + def initializer(self): + return self._init_op + + @property + @deprecated("2018-12-15", "Use `initializer` instead.") + def init(self): + return self.initializer @property def default_value(self): """The default value of the table.""" return self._default_value - @property - def init(self): - """The table initialization op.""" - return self._init - def size(self, name=None): """Compute the number of elements in this table. @@ -195,9 +190,10 @@ def size(self, name=None): Returns: A scalar tensor containing the number of elements in this table. """ - with ops.name_scope(name, "%s_Size" % self._name, - [self._table_ref]) as scope: - return gen_lookup_ops.lookup_table_size_v2(self._table_ref, name=scope) + with ops.name_scope(name, "%s_Size" % self.name, + [self.resource_handle]) as scope: + return gen_lookup_ops.lookup_table_size_v2( + self.resource_handle, name=scope) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -223,11 +219,11 @@ def lookup(self, keys, name=None): raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_Lookup" % self._name, - (self._table_ref, key_tensor, - self._default_value)) as scope: + with ops.name_scope( + name, "%s_Lookup" % self.name, + (self.resource_handle, key_tensor, self._default_value)) as scope: values = gen_lookup_ops.lookup_table_find_v2( - self._table_ref, key_tensor, self._default_value, name=scope) + self.resource_handle, key_tensor, self._default_value, name=scope) values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): @@ -269,16 +265,28 @@ def __init__(self, initializer, default_value, shared_name=None, name=None): Returns: A `HashTable` object. """ - with ops.name_scope(name, "hash_table", (initializer, - default_value)) as scope: + self._initializer = initializer + self._default_value = default_value + self._shared_name = shared_name + self._name = name + self._table_name = "" + super(HashTable, self).__init__(default_value, initializer) + self._value_shape = self._default_value.get_shape() + + def create_resource(self): + with ops.name_scope(self._name, "hash_table", + (self._initializer, self._default_value)) as scope: table_ref = gen_lookup_ops.hash_table_v2( - shared_name=shared_name, - key_dtype=initializer.key_dtype, - value_dtype=initializer.value_dtype, + shared_name=self._shared_name, + key_dtype=self._initializer.key_dtype, + value_dtype=self._initializer.value_dtype, name=scope) + self._table_name = scope.split("/")[-2] + return table_ref - super(HashTable, self).__init__(table_ref, default_value, initializer) - self._value_shape = self._default_value.get_shape() + @property + def name(self): + return self._table_name def export(self, name=None): """Returns tensors of all keys and values in the table. @@ -290,11 +298,11 @@ def export(self, name=None): A pair of tensors with the first tensor containing all keys and the second tensors containing all values in the table. """ - with ops.name_scope(name, "%s_Export" % self._name, - [self._table_ref]) as name: - with ops.colocate_with(self._table_ref): + with ops.name_scope(name, "%s_Export" % self.name, + [self.resource_handle]) as name: + with ops.colocate_with(self.resource_handle): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self._table_ref, self._key_dtype, self._value_dtype, name=name) + self.resource_handle, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( self._value_shape)) @@ -366,7 +374,7 @@ def initialize(self, table): """ _check_table_dtypes(table, self._keys.dtype, self._values.dtype) with ops.name_scope( - self._name, values=(table.table_ref, self._keys, + self._name, values=(table.resource_handle, self._keys, self._values)) as scope: if context.executing_eagerly(): # Ensure a unique name when eager execution is enabled to avoid spurious @@ -374,11 +382,11 @@ def initialize(self, table): scope += str(ops.uid()) if fwd_compat.forward_compatible(2018, 9, 19): init_op = gen_lookup_ops.lookup_table_import_v2( - table.table_ref, self._keys, self._values, name=scope) + table.resource_handle, self._keys, self._values, name=scope) else: # To maintain forward compatibiltiy, use the old implementation. init_op = gen_lookup_ops.initialize_table_v2( - table.table_ref, self._keys, self._values, name=scope) + table.resource_handle, self._keys, self._values, name=scope) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op @@ -538,11 +546,11 @@ def initialize(self, table): """ _check_table_dtypes(table, self.key_dtype, self.value_dtype) with ops.name_scope(self._name, "text_file_init", - (table.table_ref,)) as scope: + (table.resource_handle,)) as scope: filename = ops.convert_to_tensor( self._filename, dtypes.string, name="asset_filepath") init_op = gen_lookup_ops.initialize_table_from_text_file_v2( - table.table_ref, + table.resource_handle, filename, self._key_index, self._value_index, @@ -806,36 +814,42 @@ def __init__(self, raise TypeError( "hasher_spec must be of type HasherSpec, got %s" % hasher_spec) self._hasher_spec = hasher_spec - super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, - name.split("/")[-1]) + self._table_name = name.split("/")[-1] + super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64) - @property - def init(self): - """The table initialization op.""" - if self._table: - return self._table.init + def create_resource(self): + if self._table is not None: + return self._table.create_resource() + return None + + def initialize(self): + if self._table is not None: + return self._table.initialize() with ops.name_scope(None, "init"): return control_flow_ops.no_op() @property - def table_ref(self): - """Returns the table_ref of the underlying table, if one exists. - - Only use the table_ref directly if you know what you are doing. The - table_ref does not have the "hash bucket" functionality, as that is provided - by this class. + def initializer(self): + if self._table is not None: + return self._table._init_op # pylint: disable=protected-access + with ops.name_scope(None, "init"): + return control_flow_ops.no_op() - One possible use of the table_ref is subtokenization, i.e. ops which - dynamically decompose tokens into subtokens based on the contents of the - table_ref. + @property + @deprecated("2018-12-15", "Use `initializer` instead.") + def init(self): + return self.initializer - Returns: - the underlying table_ref, or None if there is no underlying table - """ + @property + def resource_handle(self): if self._table is not None: - return self._table.table_ref + return self._table.resource_handle return None + @property + def name(self): + return self._table_name + def size(self, name=None): """Compute the number of elements in this table.""" with ops.name_scope(name, "%s_Size" % self.name) as scope: @@ -1139,7 +1153,6 @@ def index_table_from_tensor(vocabulary_list, hasher_spec=hasher_spec, name=feat_to_id_scope, key_dtype=dtype) - return table diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index c9374006ba3db7..d247e7b2463bd0 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2057,6 +2057,107 @@ def matmul(a, a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name) +@tf_export("linalg.matvec") +def matvec(a, + b, + transpose_a=False, + adjoint_a=False, + a_is_sparse=False, + b_is_sparse=False, + name=None): + """Multiplies matrix `a` by vector `b`, producing `a` * `b`. + + The matrix `a` must, following any transpositions, be a tensor of rank >= 2, + and we must have `shape(b) = shape(a)[:-2] + [shape(a)[-1]]`. + + Both `a` and `b` must be of the same type. The supported types are: + `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`. + + Matrix `a` can be transposed or adjointed (conjugated and transposed) on + the fly by setting one of the corresponding flag to `True`. These are `False` + by default. + + If one or both of the inputs contain a lot of zeros, a more efficient + multiplication algorithm can be used by setting the corresponding + `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default. + This optimization is only available for plain matrices/vectors (rank-2/1 + tensors) with datatypes `bfloat16` or `float32`. + + For example: + + ```python + # 2-D tensor `a` + # [[1, 2, 3], + # [4, 5, 6]] + a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) + + # 1-D tensor `b` + # [7, 9, 11] + b = tf.constant([7, 9, 11], shape=[3]) + + # `a` * `b` + # [ 58, 64] + c = tf.matvec(a, b) + + + # 3-D tensor `a` + # [[[ 1, 2, 3], + # [ 4, 5, 6]], + # [[ 7, 8, 9], + # [10, 11, 12]]] + a = tf.constant(np.arange(1, 13, dtype=np.int32), + shape=[2, 2, 3]) + + # 2-D tensor `b` + # [[13, 14, 15], + # [16, 17, 18]] + b = tf.constant(np.arange(13, 19, dtype=np.int32), + shape=[2, 3]) + + # `a` * `b` + # [[ 86, 212], + # [410, 563]] + c = tf.matvec(a, b) + ``` + + Args: + a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`, + `complex128` and rank > 1. + b: `Tensor` with same type and rank = `rank(a) - 1`. + transpose_a: If `True`, `a` is transposed before multiplication. + adjoint_a: If `True`, `a` is conjugated and transposed before + multiplication. + a_is_sparse: If `True`, `a` is treated as a sparse matrix. + b_is_sparse: If `True`, `b` is treated as a sparse matrix. + name: Name for the operation (optional). + + Returns: + A `Tensor` of the same type as `a` and `b` where each inner-most vector is + the product of the corresponding matrices in `a` and vectors in `b`, e.g. if + all transpose or adjoint attributes are `False`: + + `output`[..., i] = sum_k (`a`[..., i, k] * `b`[..., k]), for all indices i. + + Note: This is matrix-vector product, not element-wise product. + + + Raises: + ValueError: If transpose_a and adjoint_a are both set to True. + """ + with ops.name_scope(name, "MatVec", [a, b]) as name: + # matvec is achieved by reshaping b into a matrix (appending a singleton), + # then squeezing out the trailing dim of the result. There are other ways + # to do this, e.g. using tf.expand_dims and tf.squeeze. What we have here + # has been found to be most memory efficient on TPU. + return matmul( + a, + b[..., array_ops.newaxis], + transpose_a=transpose_a, + adjoint_a=adjoint_a, + a_is_sparse=a_is_sparse, + b_is_sparse=b_is_sparse)[..., 0] + + _OverrideBinaryOperatorHelper(matmul, "matmul") sparse_matmul = gen_math_ops.sparse_mat_mul diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index d0919bdbe46d20..e86a3b85360ae2 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -225,7 +225,7 @@ def _safe_div(numerator, denominator, name): 0 if `denominator` <= 0, else `numerator` / `denominator` """ if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator) + return math_ops.div_no_nan(numerator, denominator, name=name) t = math_ops.truediv(numerator, denominator) zero = array_ops.zeros_like(t, dtype=denominator.dtype) condition = math_ops.greater(denominator, zero) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index e31d162285beb8..74343f832bcecb 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1609,6 +1609,8 @@ def leaky_relu(features, alpha=0.2, name=None): if features.dtype.is_integer: features = math_ops.to_float(features) if compat.forward_compatible(2018, 11, 1): + if isinstance(alpha, np.ndarray): + alpha = np.asscalar(alpha) return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name) alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha") return math_ops.maximum(alpha * features, features, name=name) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 10c57333ba576b..488b6fcbcdb2fb 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -1504,3 +1504,21 @@ def is_resource_variable(var): """"Returns True if `var` is to be considered a ResourceVariable.""" return isinstance(var, ResourceVariable) or hasattr( var, "_should_act_as_resource_variable") + + +def copy_to_graph_uninitialized(var): + """Copies an existing variable to a new graph, with no initializer.""" + # Like ResourceVariable.__deepcopy__, but does not set an initializer on the + # new variable. + # pylint: disable=protected-access + new_variable = ResourceVariable( + initial_value=array_ops.placeholder( + shape=var.shape, dtype=var.dtype, + name="unused_initial_variable_value"), + trainable=var.trainable, + constraint=var._constraint, + dtype=var.dtype, + name=var._shared_name) + new_variable._maybe_initialize_checkpointable() + # pylint: enable=protected-access + return new_variable diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index 720be098c25a87..c6cf2fe9adf58b 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -13,11 +13,7 @@ # limitations under the License. # ============================================================================== -"""Tensor Handle Operations. - -See the [Session Ops](https://tensorflow.org/api_guides/python/session_ops) -guide. -""" +"""Tensor Handle Operations.""" # pylint: disable=g-bad-name from __future__ import absolute_import diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 077e4558b7cc0f..b98c7f5f65b641 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -16,7 +16,7 @@ # pylint: disable=g-short-docstring-punctuation """Sparse Tensor Representation. -See the [Sparse Ops](https://tensorflow.org/api_guides/python/sparse_ops) guide. +See also `tf.SparseTensor`. """ from __future__ import absolute_import diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 920047f38b07e6..76684f89f8ac93 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -15,7 +15,7 @@ """Variables. -See the [Variables](https://tensorflow.org/api_guides/python/state_ops) guide. +See the [Variables](https://www.tensorflow.org/guide/variables) guide. """ from __future__ import absolute_import diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index ed14aa7d90060a..25e86cadeb6496 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -13,10 +13,7 @@ # limitations under the License. # ============================================================================== -"""Operations for working with string Tensors. - -See the [Strings](https://tensorflow.org/api_guides/python/string_ops) guide. -""" +"""Operations for working with string Tensors.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 4b3445a0bc44ad..e43736069e38a6 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -58,6 +58,21 @@ def getter(**kwargs): return getter +def _has_cycle(op, path): + """Detect cycles in the dependencies of `initial_value`.""" + if op.name in path: + return True + path.add(op.name) + for op_input in op.inputs: + if _has_cycle(op_input.op, path): + return True + for op_control_input in op.control_inputs: + if _has_cycle(op_control_input, path): + return True + path.remove(op.name) + return False + + @tf_export("VariableSynchronization") class VariableSynchronization(enum.Enum): """Indicates when a distributed variable will be synced. @@ -2172,20 +2187,7 @@ def _try_guard_against_uninitialized_dependencies(self, initial_value): raise TypeError("initial_value needs to be a Tensor: %s" % initial_value) # Don't modify initial_value if it contains any cyclic dependencies. - def has_cycle(op, path): - """Detect cycles in the dependencies of `initial_value`.""" - if op.name in path: - return True - path.add(op.name) - for op_input in op.inputs: - if has_cycle(op_input.op, path): - return True - for op_control_input in op.control_inputs: - if has_cycle(op_control_input, path): - return True - path.remove(op.name) - return False - if has_cycle(initial_value.op, path=set()): + if _has_cycle(initial_value.op, path=set()): return initial_value return self._safe_initial_value_from_tensor(initial_value, op_cache={}) diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index bf883fe0672b72..254fae11f4b08e 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -50,9 +51,24 @@ # to them and then pass those in as data inputs. This should probably be # handled in the CapturingGraph itself. +# Op types that output a resource tensor representing a TensorArray handle. +TENSOR_ARRAY_HANDLE_OPS = ( + "TensorArrayV3", + "TensorArrayGradV3", + "TensorArrayGradWithShape", +) -def while_loop(cond, body, loop_vars, shape_invariants=None, name=None): + +def while_loop(cond, + body, + loop_vars, + shape_invariants=None, + maximum_iterations=None, + name=None): """Like tf.while_loop, except emits a single While op.""" + if _is_in_xla_context() and maximum_iterations is None: + raise ValueError("maximum_iterations is required in XLA context.") + # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. @@ -69,6 +85,13 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, name=None): else: shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars) + if maximum_iterations is not None: + maximum_iterations = ops.convert_to_tensor( + maximum_iterations, name="maximum_iterations") + if maximum_iterations.shape.ndims != 0: + raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % + maximum_iterations.shape) + if not name: name = "while" @@ -77,8 +100,13 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, name=None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") + loop_counter = constant_op.constant( + 0, + dtype=maximum_iterations.dtype + if maximum_iterations is not None else None, + name="loop_counter") # Add loop counter needed for computing gradients. - loop_vars = [constant_op.constant(0., name="loop_counter")] + loop_vars + loop_vars = [loop_counter] + loop_vars shape_invariants = [tensor_shape.scalar()] + shape_invariants @@ -87,13 +115,18 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, name=None): add_control_dependencies = util.in_defun() # Build a `cond` wrapper that can handle the extra counter loop_var. - def wrapped_cond(unused_loop_counter, *args): + def wrapped_cond(loop_counter, *args): # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. - return cond(*_pack_sequence_as(orig_loop_vars, args)) + if maximum_iterations is None: + return cond(*_pack_sequence_as(orig_loop_vars, args)) + else: + return math_ops.logical_and( + loop_counter < maximum_iterations, + cond(*_pack_sequence_as(orig_loop_vars, args))) cond_graph = func_graph_module.func_graph_from_py_func( cond_name, @@ -243,13 +276,25 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name """The gradient of a While op produced by while_loop.""" body_graph = _get_body_graph(op) - # Set the incoming gradient of TensorArray handle to None. - # TODO(b/118164915): We need a way of distinguising b/w TensorArray resource - # handles and ResourceVariables and set the default gradient of only the - # TensorArray handle to None. + # Set the incoming gradient of TensorArray handles to None. The gradient + # implementation currently assumes all resource tensors correspond to float32 + # ResourceVariables, which can lead to runtime shape errors when used with a + # TensorArray. This is a workaround until TensorArrays are reimplemented with + # TensorLists instead of resources. + # Also set the incoming gradient of non-trainable inputs to None. It is + # possible that we receive non-None gradients for non-trainable types in + # nested while loops because we accumulate outputs of the inner while as + # variant tensors which are trainable and hence receive zeros_like tensors in + # the gradient pass. The non-trainable tensors then receive the popped zeros + # tensor from this zeros variant. The gradient for the loop vars corresponding + # to these tensors is None or zeros (this happens only if the loop var is + # accumulated as well) in _grad_fn so we reset these. + # TODO(b/118712257): Remove the IsTrainable filter once we can handle None + # output grads in _grad_fn. grads = [ - None if output.dtype == dtypes.resource else g - for g, output in zip(grads, op.outputs) + None if _is_tensor_array_handle(output) or + not gradients_impl.IsTrainable(output) else grad + for grad, output in zip(grads, op.outputs) ] # Ensure that all non-resource trainable outputs have incoming gradients. @@ -363,8 +408,9 @@ def _create_grad_func(ys, xs, grads, func_graph, name, while_op): """ assert len(ys) == len(grads) - counter = constant_op.constant(0.) total_iters = while_op.outputs[0] + counter = constant_op.constant( + 0, dtype=total_iters.dtype, name="grad_counter") args = [counter, total_iters] + list(grads) # Note: The returned function does not have `args` in the list of @@ -402,7 +448,7 @@ def _grad_fn(ys, xs, args, func_graph): args: The input arguments. args[0] - Loop counter args[1] - Total number of iterations. - args[2:] - Incoming gradients for `func_graph.outputs`. + args[2:] - Incoming gradients for `ys`. func_graph: function.FuncGraph. The corresponding forward-pass function. Returns: @@ -418,6 +464,8 @@ def _grad_fn(ys, xs, args, func_graph): grad_outs = gradients_impl._GradientsHelper( ys, xs, grad_ys=grad_ys, src_graph=func_graph) + # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there + # is a tf.StopGradient in the loop body. assert all([g is not None for g in grad_outs]) counter = args[0] total_iters = args[1] @@ -728,6 +776,14 @@ def _maybe_set_lowering_attr(op): # pylint: enable=protected-access +# TODO(srbs): This method should be in control_flow_util but that introduces +# a circular dependency ops -> control_flow_util -> ops. +def _is_in_xla_context(): + """Returns whether the current context is inside an XLA context.""" + cur_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingXLAContext(cur_ctxt) is not None + + def _get_tensor_convertible_shape(shape): assert isinstance(shape, tensor_shape.TensorShape) if shape.is_fully_defined(): @@ -746,6 +802,24 @@ def _graph_name(graph): return "Base" +def _is_tensor_array_handle(tensor): + """Returns whether tensor is a TensorArray handle.""" + if tensor.dtype != dtypes.resource: + return False + + if tensor.op.type == "While": + # We assume that any resource outputs of a While op correspond to a captured + # resource input (as opposed to a loop variable specified by the user). + # NOTE(skyewm): we could actually check this, but I can't think of when you + # would have a resource loop variable. + tensor = tensor.op.inputs[tensor.value_index] + + # TODO(b/118452219): add test coverage for this. + tensor = func_graph_module.maybe_captured(tensor) + + return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS + + def _pack_sequence_as(structure_with_tas, loop_vars): """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays.""" @@ -783,7 +857,7 @@ def f(maybe_ta): def _build_signature(loop_vars, shape_invariants): return nest.pack_sequence_as(loop_vars, [ - tensor_spec.TensorSpec(s, t.dtype) + tensor_spec.TensorSpec(s, t.dtype, name=t.op.name) for s, t in zip(nest.flatten(shape_invariants), nest.flatten(loop_vars)) ]) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 576ad8ed65cfa6..e7a3b8afd5daf2 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -21,9 +21,9 @@ py_library( deps = [ ":builder", ":constants", - ":export", ":loader", ":main_op", + ":save", ":signature_constants", ":signature_def_utils", ":simple_save", @@ -265,9 +265,9 @@ py_test( ) py_library( - name = "export", + name = "save", srcs = [ - "export.py", + "save.py", ], srcs_version = "PY2AND3", deps = [ @@ -285,11 +285,11 @@ py_library( ) py_test( - name = "export_test", - srcs = ["export_test.py"], + name = "save_test", + srcs = ["save_test.py"], srcs_version = "PY2AND3", deps = [ - ":export", + ":save", ":signature_constants", ":tag_constants", "//tensorflow/python/eager:def_function", diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 8e7bea36de77b4..4f68f7c5aeac4e 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -36,14 +36,10 @@ from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args -from tensorflow.python.util.deprecation import deprecated_endpoints from tensorflow.python.util.tf_export import tf_export -@tf_export( - "saved_model.Builder", - v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"]) -@deprecated_endpoints("saved_model.builder.SavedModelBuilder") +@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"]) class SavedModelBuilder(object): """Builds the `SavedModel` protocol buffer and saves variables and assets. @@ -82,6 +78,11 @@ class SavedModelBuilder(object): builder.save() ``` + + Note: This function will only be available through the v1 compatibility + library as tf.compat.v1.saved_model.builder.SavedModelBuilder or + tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new + object-based method of creating SavedModels. """ def __init__(self, export_dir): diff --git a/tensorflow/python/saved_model/export.py b/tensorflow/python/saved_model/export.py deleted file mode 100644 index 4d4ff236f66652..00000000000000 --- a/tensorflow/python/saved_model/export.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exports a SavedModel from a Checkpointable Python object.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import os - -from tensorflow.core.protobuf import saved_model_pb2 -from tensorflow.python.eager import def_function -from tensorflow.python.eager import function -from tensorflow.python.framework import ops -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import array_ops -from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import signature_def_utils -from tensorflow.python.saved_model import utils -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.util import compat -from tensorflow.python.util import nest - - -def _canonicalize_signatures(signatures): - """Converts `signatures` into a dictionary of concrete functions.""" - if signatures is None: - signatures = {} - elif not isinstance(signatures, collections.Mapping): - signatures = { - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures} - concrete_signatures = {} - for serving_key, signature_function in signatures.items(): - if isinstance(signature_function, (function.PolymorphicFunction, - def_function.PolymorphicFunction)): - input_signature = signature_function._input_signature # pylint: disable=protected-access - if input_signature is None: - raise ValueError( - ("Unable to use the function {} as a signature directly. Functions " - "used to generate serving signatures must either have an " - "`input_signature=` specified when constructed, or must be " - "converted to concrete functions using " - "`f.get_concrete_function(...)`.").format(signature_function)) - signature_function = signature_function.get_concrete_function() - elif not isinstance(signature_function, function.Function): - raise ValueError( - ("Expected a TensorFlow function to generate a signature for, but " - "got {}. Python functions may be decorated with " - "`@tf.function(input_signature=...)` and passed as signatures " - "directly, or created without a signature using `@tf.function` " - "and then converted to a concrete TensorFlow function using " - "`f.get_concrete_function(...)`.").format(signature_function)) - concrete_signatures[serving_key] = signature_function - return concrete_signatures - - -def _is_flat(sequence): - sequence_flat = nest.flatten(sequence) - try: - nest.assert_same_structure(sequence_flat, sequence) - return True - except ValueError: - return False - except TypeError: - return False - - -def _normalize_outputs(outputs, function_name, signature_key): - """Construct an output dictionary from unnormalized function outputs.""" - if isinstance(outputs, collections.Mapping): - for key, value in outputs.items(): - if not isinstance(value, ops.Tensor): - raise ValueError( - ("Got a dictionary containing non-Tensor value {} for key {} " - "in the output of the function {} used to generate a SavedModel " - "signature. Dictionaries outputs for functions used as signatures " - "should have one Tensor output per string key.") - .format(value, key, compat.as_str_any(function_name))) - return outputs - else: - original_outputs = outputs - if not isinstance(outputs, collections.Sequence): - outputs = [outputs] - if not _is_flat(outputs): - raise ValueError( - ("Got non-flat outputs '{}' from '{}' for SavedModel " - "signature '{}'. Signatures have one Tensor per output, so " - "to have predictable names Python functions used to generate " - "these signatures should avoid outputting Tensors in nested " - "structures.") - .format(original_outputs, function_name, signature_key)) - return {("output_{}".format(output_index)): output - for output_index, output - in enumerate(outputs)} - - -def _tensor_dict_to_tensorinfo(tensor_dict): - return {key: utils.build_tensor_info(value) - for key, value in tensor_dict.items()} - - -def _generate_signatures(signature_functions): - """Validates and calls `signature_functions` in the default graph. - - Args: - signature_functions: A dictionary mapping string keys to concrete TensorFlow - functions (e.g. from `_canonicalize_signatures`) which will be used to - generate SignatureDefs. - - Returns: - Each function in the `signature_functions` dictionary is called with - placeholder Tensors, generating a function call operation and output - Tensors. The placeholder Tensors, the function call operation, and the - output Tensors from the function call are part of the default Graph. - - This function then returns a dictionary with the same structure as - `signature_functions`, with the concrete functions replaced by SignatureDefs - implicitly containing information about how to call each function from a - TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference - the generated placeholders and Tensor outputs by name. - - The caller is expected to include the default Graph set while calling this - function as a MetaGraph in a SavedModel, including the returned - SignatureDefs as part of that MetaGraph. - """ - signatures = {} - for signature_key, func in sorted(signature_functions.items()): - func.add_to_graph(register_gradient_functions=True) - # `exterior_placeholders` holds placeholders which are outside the function - # body, directly contained in a MetaGraph of the SavedModel. The function - # body itself contains nearly identical placeholders used when running the - # function, but these exterior placeholders allow Session-based APIs to call - # the function using feeds and fetches which name Tensors in the MetaGraph. - exterior_placeholders = {} - kwargs = {} - for placeholder in func.inputs: - user_input_name = compat.as_str_any( - placeholder.op.get_attr("_user_specified_name")) - # If the internal placeholders for a function have names which were - # uniquified by TensorFlow, then a single user-specified argument name - # must refer to multiple Tensors. The resulting signatures would be - # confusing to call. Instead, we throw an exception telling the user to - # specify explicit names. - if user_input_name != placeholder.op.name: - # This should be unreachable, since concrete functions may not be - # generated with non-unique argument names. - raise ValueError( - ("Got non-flat/non-unique argument names for SavedModel " - "signature '{}': more than one argument to '{}' was named '{}'. " - "Signatures have one Tensor per named input, so to have " - "predictable names Python functions used to generate these " - "signatures should avoid *args and Tensors in nested " - "structures unless unique names are specified for each. Use " - "tf.TensorSpec(..., name=...) to provide a name for a Tensor " - "input.") - .format(signature_key, compat.as_str_any(func.name), - user_input_name)) - arg_placeholder = array_ops.placeholder( - shape=placeholder.shape, - dtype=placeholder.dtype, - name="{}_{}".format(signature_key, user_input_name)) - exterior_placeholders[user_input_name] = arg_placeholder - kwargs[user_input_name] = arg_placeholder - outputs = _normalize_outputs( - func(**kwargs), func.name, signature_key) - signatures[signature_key] = signature_def_utils.build_signature_def( - _tensor_dict_to_tensorinfo(exterior_placeholders), - _tensor_dict_to_tensorinfo(outputs)) - return signatures - - -def _make_graph_def(signature_functions): - """Generates and exports call ops for `signature_functions`.""" - # TODO(allenl): Handle variables - signatures = {} - exported_graph = ops.Graph() - with exported_graph.as_default(): - signatures = _generate_signatures(signature_functions) - graph_def = exported_graph.as_graph_def(add_shapes=True) - return graph_def, signatures - - -def export(obj, export_dir, signatures=None): - # pylint: disable=line-too-long - """Exports the Checkpointable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md). - - The `signatures` argument indicates TensorFlow functions which will be - available to programs which consume `SavedModel`s, for example serving - APIs. Python functions may be decorated with - `@tf.function(input_signature=...)` and passed as signatures directly, or - created without a signature using `@tf.function` and then converted to a - concrete TensorFlow function using `f.get_concrete_function(...)`. - - In either case, `Tensor` inputs to `signatures` functions which are not - associated with a unique Python argument name must have names explicitly - specified in their `tf.TensorSpec` objects. Cases where this is necessary - include positional arguments passed through variadic `*args` and multiple - `Tensor` inputs which are part of the same nested structure. - - The outputs of functions used as `signatures` must either be flat lists, in - which case outputs will be numbered, or a dictionary mapping string keys to - Tensors, in which case the string keys will be used to name outputs. - - Exporting with a signature specified: - - ```python - class Model(tf.keras.Model): - - @tf.function(input_signature=tf.TensorSpec(shape=[None], dtype=tf.string)) - def serve(serialized): - ... - - m = Model() - tf.saved_model.export(m, '/tmp/saved_model/', signatures=m.serve) - ``` - - Exporting from a function without a fixed signature: - - ```python - class Model(tf.keras.Model): - - @tf.function - def compute(x): - ... - - m = Model() - tf.saved_model.export( - m, '/tmp/saved_model/', - signatures=m.compute.get_concrete_function( - tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp"))) - ``` - - Args: - obj: A checkpointable object to export. - export_dir: A directory in which to write the SavedModel. - signatures: Optional, either a `tf.function` with an input signature - specified or the result of `f.get_concrete_function` on a - `tf.function`-decorated function `f`, in which case `f` will be used to - generate a signature for the SavedModel under the default serving - signature key. `signatures` may also be a dictionary, in which case it - maps from signature keys to either `tf.function` instances with input - signatures or concrete functions. The keys of such a dictionary may be - arbitrary strings, but will typically be from the - `tf.saved_model.signature_constants` module. - - Raises: - ValueError: If `obj` is not checkpointable. - """ - # pylint: enable=line-too-long - if not isinstance(obj, checkpointable.CheckpointableBase): - raise ValueError( - "Expected a Checkpointable object for export, got {}.".format(obj)) - signatures = _canonicalize_signatures(signatures) - graph_def, signatures = _make_graph_def(signatures) - saved_model = saved_model_pb2.SavedModel() - saved_model.saved_model_schema_version = ( - constants.SAVED_MODEL_SCHEMA_VERSION) - meta_graph_def = saved_model.meta_graphs.add() - # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x - # compatible (no sessions) and share it with this export API rather than - # making a SavedModel proto and writing it directly. - meta_graph_def.graph_def.MergeFrom(graph_def) - for signature_key, signature in signatures.items(): - meta_graph_def.signature_def[signature_key].MergeFrom(signature) - file_io.recursive_create_dir(export_dir) - path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) - file_io.write_string_to_file(path, saved_model.SerializeToString()) diff --git a/tensorflow/python/saved_model/export_test.py b/tensorflow/python/saved_model/export_test.py deleted file mode 100644 index 4131b45ce4073f..00000000000000 --- a/tensorflow/python/saved_model/export_test.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for checkpointable object SavedModel export.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.python.eager import def_function -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec -from tensorflow.python.saved_model import export -from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.training.checkpointable import tracking - - -class ExportTest(test.TestCase): - - def _import_and_infer( - self, export_dir, inputs, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY): - """Import a SavedModel into a TF 1.x-style graph and run `signature_key`.""" - graph = ops.Graph() - with graph.as_default(), self.session(graph) as session: - model = loader.load(session, [], export_dir) - signature = model.signature_def[signature_key] - self.assertEqual(set(inputs.keys()), set(signature.inputs.keys())) - feed_dict = {} - for arg_name in inputs.keys(): - feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = ( - inputs[arg_name]) - output_dict = {} - for output_name, output_tensor_info in signature.outputs.items(): - output_dict[output_name] = graph.get_tensor_by_name( - output_tensor_info.name) - return session.run(output_dict, feed_dict=feed_dict) - - def test_method_export_signature(self): - root = tracking.Checkpointable() - root.f = def_function.function( - lambda x: 2. * x, - input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) - root.f(constant_op.constant(1.)) - export_dir = os.path.join(self.get_temp_dir(), "saved_model") - export.export(root, export_dir, root.f) - self.assertEqual( - {"output_0": 2.}, - self._import_and_infer(export_dir, {"x": 1.})) - - def test_method_export_concrete(self): - root = tracking.Checkpointable() - root.f = def_function.function( - lambda z: {"out": 2. * z}) - root.f(constant_op.constant(1.)) - export_dir = os.path.join(self.get_temp_dir(), "saved_model") - export.export( - root, - export_dir, - {"non_default_key": root.f.get_concrete_function( - tensor_spec.TensorSpec(None, dtypes.float32))}) - self.assertEqual( - {"out": 2.}, - self._import_and_infer( - export_dir, {"z": 1.}, signature_key="non_default_key")) - - def test_non_concrete_error(self): - root = tracking.Checkpointable() - root.f = def_function.function(lambda x: 2. * x) - root.f(constant_op.constant(1.)) - export_dir = os.path.join(self.get_temp_dir(), "saved_model") - with self.assertRaisesRegexp( - ValueError, "must be converted to concrete functions"): - export.export(root, export_dir, root.f) - - def test_nested_inputs(self): - root = tracking.Checkpointable() - root.f = def_function.function( - lambda x: 2. * x[0], - input_signature=([tensor_spec.TensorSpec(None, dtypes.float32), - tensor_spec.TensorSpec(None, dtypes.float32)],)) - root.f([constant_op.constant(1.), constant_op.constant(1.)]) - # Concrete functions must always have uniquely named Tensor inputs. Export - # relies on this. - with self.assertRaisesRegexp( - ValueError, "two arguments named 'x'"): - root.f.get_concrete_function() - - def test_nested_outputs(self): - root = tracking.Checkpointable() - root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x))) - root.f(constant_op.constant(1.)) - to_export = root.f.get_concrete_function(constant_op.constant(1.)) - export_dir = os.path.join(self.get_temp_dir(), "saved_model") - with self.assertRaisesRegexp( - ValueError, "non-flat outputs"): - export.export(root, export_dir, to_export) - - def test_nested_dict_outputs(self): - root = tracking.Checkpointable() - root.f = def_function.function( - lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)}) - root.f(constant_op.constant(1.)) - to_export = root.f.get_concrete_function(constant_op.constant(1.)) - export_dir = os.path.join(self.get_temp_dir(), "saved_model") - with self.assertRaisesRegexp( - ValueError, "dictionary containing non-Tensor value"): - export.export(root, export_dir, to_export) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 55ef273fee3f94..8c8eaf038a1b90 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -173,9 +173,13 @@ def maybe_saved_model_directory(export_dir): return file_io.file_exists(txt_path) or file_io.file_exists(pb_path) -@tf_export("saved_model.load", - v1=["saved_model.load", "saved_model.loader.load"]) -@deprecation.deprecated_endpoints("saved_model.loader.load") +@tf_export(v1=["saved_model.load", "saved_model.loader.load"]) +@deprecation.deprecated( + None, + "This function will only be available through the v1 compatibility " + "library as tf.compat.v1.saved_model.loader.load or " + "tf.compat.v1.saved_model.load. There will be a new function for importing " + "SavedModels in Tensorflow 2.0.") def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): """Loads the model from a SavedModel as specified by tags. diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py index d567b95795b949..bc0d38930eb8fa 100644 --- a/tensorflow/python/saved_model/main_op_impl.py +++ b/tensorflow/python/saved_model/main_op_impl.py @@ -26,8 +26,11 @@ from tensorflow.python.util.tf_export import tf_export -@tf_export('saved_model.main_op', v1=['saved_model.main_op.main_op']) -@deprecation.deprecated_endpoints('saved_model.main_op.main_op') +@tf_export(v1=['saved_model.main_op.main_op']) +@deprecation.deprecated( + None, + 'This function will only be available through the v1 compatibility ' + 'library as tf.compat.v1.saved_model.main_op.main_op.') def main_op(): """Returns a main op to init variables and tables. @@ -44,13 +47,13 @@ def main_op(): # TODO(sukritiramesh): Integrate with Saver for complete restore functionality. -@tf_export( - 'saved_model.main_op_with_restore', - v1=[ - 'saved_model.main_op_with_restore', - 'saved_model.main_op.main_op_with_restore' - ]) -@deprecation.deprecated_endpoints('saved_model.main_op.main_op_with_restore') +@tf_export(v1=['saved_model.main_op_with_restore', + 'saved_model.main_op.main_op_with_restore']) +@deprecation.deprecated( + None, + 'This function will only be available through the v1 compatibility ' + 'library as tf.compat.v1.saved_model.main_op_with_restore or ' + 'tf.compat.v1.saved_model.main_op.main_op_with_restore.') def main_op_with_restore(restore_op_name): """Returns a main op to init variables, tables and restore the graph. diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py new file mode 100644 index 00000000000000..63575f631eb0c0 --- /dev/null +++ b/tensorflow/python/saved_model/save.py @@ -0,0 +1,540 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Exports a SavedModel from a Checkpointable Python object.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os + +from tensorflow.core.protobuf import saved_model_pb2 +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import function +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils_impl +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import util +from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export + + +def _find_function_to_export(root): + """Iterate over `root`'s attributes, finding traced functions.""" + functions = [] + function_attribute_names = [] + for attribute_name in dir(root): + attribute_value = getattr(root, attribute_name, None) + if isinstance(attribute_value, def_function.PolymorphicFunction): + functions.append(attribute_value) + function_attribute_names.append(attribute_name) + # TODO(allenl): Automatically infer signatures for Keras functional models? + if not functions: + raise ValueError( + ("Exporting an object with no tf.saved_model.save(..., signatures=...) " + "argument specified, and with no @tf.function-decorated methods " + "attached to it. In the future this will be a supported use-case for " + "Python re-import, but at the moment saving a SavedModel without " + "signatures does not make sense, as the only consumers will expect " + "signatures. Either decorate a method or specify a signature function " + "explicitly.")) + elif len(functions) > 1: + raise ValueError( + ("Exporting an object with no tf.saved_model.save(..., signatures=...) " + "argument specified, and with more than one @tf.function-decorated " + "method attached to it: {}. The signature keys for these functions " + "are ambiguous. Specify signature functions explicitly.").format( + function_attribute_names)) + return functions[0] + + +def _canonicalize_signatures(signatures): + """Converts `signatures` into a dictionary of concrete functions.""" + if not isinstance(signatures, collections.Mapping): + signatures = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures} + concrete_signatures = {} + for serving_key, signature_function in signatures.items(): + if isinstance(signature_function, (function.PolymorphicFunction, + def_function.PolymorphicFunction)): + input_signature = signature_function._input_signature # pylint: disable=protected-access + if input_signature is None: + raise ValueError( + ("Unable to use the function {} as a signature directly. Functions " + "used to generate serving signatures must either have an " + "`input_signature=` specified when constructed, or must be " + "converted to concrete functions using " + "`f.get_concrete_function(...)`.").format(signature_function)) + signature_function = signature_function.get_concrete_function() + elif not isinstance(signature_function, function.Function): + raise ValueError( + ("Expected a TensorFlow function to generate a signature for, but " + "got {}. Python functions may be decorated with " + "`@tf.function(input_signature=...)` and passed as signatures " + "directly, or created without a signature using `@tf.function` " + "and then converted to a concrete TensorFlow function using " + "`f.get_concrete_function(...)`.").format(signature_function)) + concrete_signatures[serving_key] = signature_function + return concrete_signatures + + +def _is_flat(sequence): + sequence_flat = nest.flatten(sequence) + try: + nest.assert_same_structure(sequence_flat, sequence) + return True + except ValueError: + return False + except TypeError: + return False + + +def _normalize_outputs(outputs, function_name, signature_key): + """Construct an output dictionary from unnormalized function outputs.""" + if isinstance(outputs, collections.Mapping): + for key, value in outputs.items(): + if not isinstance(value, ops.Tensor): + raise ValueError( + ("Got a dictionary containing non-Tensor value {} for key {} " + "in the output of the function {} used to generate a SavedModel " + "signature. Dictionaries outputs for functions used as signatures " + "should have one Tensor output per string key.") + .format(value, key, compat.as_str_any(function_name))) + return outputs + else: + original_outputs = outputs + if not isinstance(outputs, collections.Sequence): + outputs = [outputs] + if not _is_flat(outputs): + raise ValueError( + ("Got non-flat outputs '{}' from '{}' for SavedModel " + "signature '{}'. Signatures have one Tensor per output, so " + "to have predictable names Python functions used to generate " + "these signatures should avoid outputting Tensors in nested " + "structures.") + .format(original_outputs, function_name, signature_key)) + return {("output_{}".format(output_index)): output + for output_index, output + in enumerate(outputs)} + + +def _tensor_dict_to_tensorinfo(tensor_dict): + return {key: utils_impl.build_tensor_info(value) + for key, value in tensor_dict.items()} + + +def _map_captured_resources_to_created_resources( + original_captures, resource_map): + """Maps eager resources captured by a function to Graph resources for export. + + Args: + original_captures: A dictionary mapping from resource tensors captured by + the function to interior placeholders for those resources (inside the + function body). + resource_map: A dictionary mapping from resource tensors owned by the eager + context to resource tensors in the exported graph. + + Returns: + A dictionary mapping from interior placeholders in the function body to + exterior stand-in resource tensors which belong to the exported graph. + + Raises: + AssertionError: If the function references a resource which is not part of + `resource_map`. + """ + export_captures = {} + for exterior, interior in original_captures.items(): + mapped_resource = resource_map.get(exterior, None) + if mapped_resource is None: + raise AssertionError( + ("Tried to export a function which references untracked stateful " + "object {}. Stateful TensorFlow objects (e.g. tf.Variable) must " + "be tracked by the main object. Objects may be tracked by " + "assigning them to an attribute of another tracked object, or to " + "an attribute of the main object directly.") + .format(interior)) + export_captures[interior] = mapped_resource + return export_captures + + +def _map_function_inputs_to_created_inputs( + function_inputs, export_captures, signature_key, function_name): + """Creates exterior placeholders in the exported graph for function inputs. + + Functions have two types of inputs: tensors captured from the outside (eager) + context, and arguments to the function which we expect to receive from the + user at each call. `_map_captured_resources_to_created_resources` replaces + captured tensors with stand-ins (typically these are resource dtype tensors + associated with variables). `_map_function_inputs_to_created_inputs` runs over + every input, either captured or argument. For captures, it uses the mapped + resource from `export_captures`. For arguments, it creates a new placeholder + which will belong to the exported graph rather than the function body. + + Args: + function_inputs: A list of all placeholders in the function body. + export_captures: A dictionary mapping from interior placeholders in the + function body to exterior stand-in resource tensors which belong to the + exported graph (see `_map_captured_resources_to_created_resources`). + signature_key: The name of the signature being exported, for error messages. + function_name: The name of the function, for error messages. + + Returns: + A tuple of (mapped_inputs, exterior_placeholders) + mapped_inputs: A list with entries corresponding to `function_inputs` + containing all of the inputs of the function gathered from the exported + graph (both captured resources and arguments). + exterior_argument_placeholders: A dictionary mapping from argument names + to placeholders in the exported graph, containing the explicit arguments + to the function which a user is expected to provide. + + Raises: + ValueError: If argument names are not unique. + """ + # `exterior_argument_placeholders` holds placeholders which are outside the + # function body, directly contained in a MetaGraph of the SavedModel. The + # function body itself contains nearly identical placeholders used when + # running the function, but these exterior placeholders allow Session-based + # APIs to call the function using feeds and fetches which name Tensors in the + # MetaGraph. + exterior_argument_placeholders = {} + mapped_inputs = [] + for placeholder in function_inputs: + mapped_resource_tensor = export_captures.get(placeholder, None) + if mapped_resource_tensor is not None: + # This is a captured resource. + mapped_inputs.append(mapped_resource_tensor) + continue + # `export_captures` contains an exhaustive set of captures, so if we don't + # find the input there then we now know we have an argument. + user_input_name = compat.as_str_any( + placeholder.op.get_attr("_user_specified_name")) + # If the internal placeholders for a function have names which were + # uniquified by TensorFlow, then a single user-specified argument name + # must refer to multiple Tensors. The resulting signatures would be + # confusing to call. Instead, we throw an exception telling the user to + # specify explicit names. + if user_input_name != placeholder.op.name: + # This should be unreachable, since concrete functions may not be + # generated with non-unique argument names. + raise ValueError( + ("Got non-flat/non-unique argument names for SavedModel " + "signature '{}': more than one argument to '{}' was named '{}'. " + "Signatures have one Tensor per named input, so to have " + "predictable names Python functions used to generate these " + "signatures should avoid *args and Tensors in nested " + "structures unless unique names are specified for each. Use " + "tf.TensorSpec(..., name=...) to provide a name for a Tensor " + "input.") + .format(signature_key, compat.as_str_any(function_name), + user_input_name)) + arg_placeholder = array_ops.placeholder( + shape=placeholder.shape, + dtype=placeholder.dtype, + name="{}_{}".format(signature_key, user_input_name)) + exterior_argument_placeholders[user_input_name] = arg_placeholder + mapped_inputs.append(arg_placeholder) + return mapped_inputs, exterior_argument_placeholders + + +def _generate_signatures(signature_functions, resource_map): + """Validates and calls `signature_functions` in the default graph. + + Args: + signature_functions: A dictionary mapping string keys to concrete TensorFlow + functions (e.g. from `_canonicalize_signatures`) which will be used to + generate SignatureDefs. + resource_map: A dictionary mapping from resource tensors in the eager + context to resource tensors in the Graph being exported. This dictionary + is used to re-bind resources captured by functions to tensors which will + exist in the SavedModel. + + Returns: + Each function in the `signature_functions` dictionary is called with + placeholder Tensors, generating a function call operation and output + Tensors. The placeholder Tensors, the function call operation, and the + output Tensors from the function call are part of the default Graph. + + This function then returns a dictionary with the same structure as + `signature_functions`, with the concrete functions replaced by SignatureDefs + implicitly containing information about how to call each function from a + TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference + the generated placeholders and Tensor outputs by name. + + The caller is expected to include the default Graph set while calling this + function as a MetaGraph in a SavedModel, including the returned + SignatureDefs as part of that MetaGraph. + """ + signatures = {} + for signature_key, func in sorted(signature_functions.items()): + # Register the inference function for this signature in the exported + # graph. There is no direct use for the gradient of this function, so we + # don't generate/register a gradient function here (but may end up with one + # if another function relies on it). Users can still take symbolic gradients + # of the function on import, the gradient just won't be in the saved + # graph. When exporting a signature which already computes gradients, this + # stops us from taking needless second-order gradients. + func.add_to_graph(register_gradient_functions=False) + export_captures = _map_captured_resources_to_created_resources( + func.graph.captures, resource_map) + mapped_inputs, exterior_argument_placeholders = ( + _map_function_inputs_to_created_inputs( + func.inputs, export_captures, signature_key, func.name)) + # Calls the function quite directly, since we have new captured resource + # tensors we need to feed in which weren't part of the original function + # definition. + # pylint: disable=protected-access + outputs = _normalize_outputs( + func._build_call_outputs( + func._inference_function.call(context.context(), mapped_inputs)), + func.name, signature_key) + # pylint: enable=protected-access + signatures[signature_key] = signature_def_utils.build_signature_def( + _tensor_dict_to_tensorinfo(exterior_argument_placeholders), + _tensor_dict_to_tensorinfo(outputs)) + return signatures + + +def _map_resources(accessible_objects): + """Makes new resource handle ops corresponding to existing resource tensors. + + Creates resource handle ops in the current default graph, whereas + `accessible_objects` will be from an eager context. Resource mapping adds + resource handle ops to the main GraphDef of a SavedModel, which allows the C++ + loader API to interact with variables. + + Args: + accessible_objects: A list of objects, some of which may contain resources, + to create replacements for. + + Returns: + A tuple of (object_map, resource_map): + object_map: A dictionary mapping from object in `accessible_objects` to + replacement objects created to hold the new resource tensors. + resource_map: A dictionary mapping from resource tensors extracted from + `accessible_objects` to newly created resource tensors. + """ + # TODO(allenl, rohanj): Map generic resources rather than just variables. + # TODO(allenl): Handle MirroredVariables and other types of variables which + # may need special casing. + object_map = {} + resource_map = {} + for obj in accessible_objects: + if resource_variable_ops.is_resource_variable(obj): + new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj) + object_map[obj] = new_variable + resource_map[obj.handle] = new_variable.handle + return object_map, resource_map + + +def _make_graph_def(root, signature_functions, object_saver): + """Generates and exports call ops for `signature_functions`.""" + signatures = {} + # List objects from the eager context to make sure Optimizers give us the + # right Graph-dependent variables. + accessible_objects = util.list_objects(root) + exported_graph = ops.Graph() + with exported_graph.as_default(): + object_map, resource_map = _map_resources(accessible_objects) + # Saving an object-based checkpoint again gathers variables. We need to do the + # gathering from the eager context so Optimizers save the right set of + # variables, but want any operations associated with the save/restore to be in + # the exported graph (thus the `to_graph` argument). + saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph) + with exported_graph.as_default(): + signatures = _generate_signatures(signature_functions, resource_map) + saver_def = saver.to_proto() + graph_def = exported_graph.as_graph_def(add_shapes=True) + # Clean reference cycles so repeated export()s don't make work for the garbage + # collector. + ops.dismantle_graph(exported_graph) + return graph_def, signatures, saver_def + + +@tf_export("saved_model.save", v1=["saved_model.experimental.save"]) +def save(obj, export_dir, signatures=None): + # pylint: disable=line-too-long + """Exports the Checkpointable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md). + + Example usage: + + ```python + class Adder(tf.train.Checkpoint): + + @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) + def add(self, x): + return x + x + 1. + + to_export = Adder() + tf.saved_model.save(to_export, '/tmp/adder') + ``` + + The resulting SavedModel is then servable with an input named "x", its value + having any shape and dtype float32. + + The optional `signatures` argument controls which methods in `obj` will be + available to programs which consume `SavedModel`s, for example serving + APIs. Python functions may be decorated with + `@tf.function(input_signature=...)` and passed as signatures directly, or + lazily with a call to `get_concrete_function` on the method decorated with + `@tf.function`. + + If the `signatures` argument is omitted, `obj` will be searched for + `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that + method will be used as the default signature for the SavedModel. This behavior + is expected to change in the future, when a corresponding + `tf.saved_model.load` symbol is added. At that point signatures will be + completely optional, and any `@tf.function` attached to `obj` or its + dependencies will be exported for use with `load`. + + When invoking a signature in an exported SavedModel, `Tensor` arguments are + identified by name. These names will come from the Python function's argument + names by default. They may be overridden by specifying a `name=...` argument + in the corresponding `tf.TensorSpec` object. Explicit naming is required if + multiple `Tensor`s are passed through a single argument to the Python + function. + + The outputs of functions used as `signatures` must either be flat lists, in + which case outputs will be numbered, or a dictionary mapping string keys to + `Tensor`, in which case the keys will be used to name outputs. + + Since `tf.keras.Model` objects are also Checkpointable, this function can be + used to export Keras models. For example, exporting with a signature + specified: + + ```python + class Model(tf.keras.Model): + + @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) + def serve(self, serialized): + ... + + m = Model() + tf.saved_model.save(m, '/tmp/saved_model/') + ``` + + Exporting from a function without a fixed signature: + + ```python + class Model(tf.keras.Model): + + @tf.function + def call(self, x): + ... + + m = Model() + tf.saved_model.save( + m, '/tmp/saved_model/', + signatures=m.call.get_concrete_function( + tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp"))) + ``` + + Variables must be tracked by assigning them to an attribute of a tracked + object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers + from `tf.keras.layers`, optimizers from `tf.train`) track their variables + automatically. This is the same tracking scheme that `tf.train.Checkpoint` + uses, and an exported `Checkpoint` object may be restored as a training + checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's + "variables/" subdirectory. Currently variables are the only stateful objects + supported by `tf.saved_model.save`, but others (e.g. tables) will be supported + in the future. + + `tf.function` does not hard-code device annotations from outside the function + body, instead using the calling context's device. This means for example that + exporting a model which runs on a GPU and serving it on a CPU will generally + work, with some exceptions. `tf.device` annotations inside the body of the + function will be hard-coded in the exported model; this type of annotation is + discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with + device-specific layouts, may cause issues. Currently a `DistributionStrategy` + is another exception: active distribution strategies will cause device + placements to be hard-coded in a function. Exporting a single-device + computation and importing under a `DistributionStrategy` is not currently + supported, but may be in the future. + + SavedModels exported with `tf.saved_model.save` [strip default-valued + attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes) + automatically, which removes one source of incompatibilities when the consumer + of a SavedModel is running an older TensorFlow version than the + producer. There are however other sources of incompatibilities which are not + handled automatically, such as when the exported model contains operations + which the consumer does not have definitions for. + + The current implementation of `tf.saved_model.save` targets serving use-cases, + but omits information which will be necessary for the planned future + implementation of `tf.saved_model.load`. Exported models using the current + `save` implementation, and other existing SavedModels, will not be compatible + with `tf.saved_model.load` when it is implemented. Further, `save` will in the + future attempt to export `@tf.function`-decorated methods which it does not + currently inspect, so some objects which are exportable today will raise + exceptions on export in the future (e.g. due to complex/non-serializable + default arguments). Such backwards-incompatible API changes are expected only + prior to the TensorFlow 2.0 release. + + Args: + obj: A checkpointable object to export. + export_dir: A directory in which to write the SavedModel. + signatures: Optional, either a `tf.function` with an input signature + specified or the result of `f.get_concrete_function` on a + `@tf.function`-decorated function `f`, in which case `f` will be used to + generate a signature for the SavedModel under the default serving + signature key. `signatures` may also be a dictionary, in which case it + maps from signature keys to either `tf.function` instances with input + signatures or concrete functions. The keys of such a dictionary may be + arbitrary strings, but will typically be from the + `tf.saved_model.signature_constants` module. + + Raises: + ValueError: If `obj` is not checkpointable. + """ + # pylint: enable=line-too-long + if not isinstance(obj, base.CheckpointableBase): + raise ValueError( + "Expected a Checkpointable object for export, got {}.".format(obj)) + if signatures is None: + # Note that we run this before saving the checkpoint, since looping over + # attributes may have the side effect of creating variables in some cases. + signatures = _find_function_to_export(obj) + object_saver = util.CheckpointableSaver(obj) + utils_impl.get_or_create_variables_dir(export_dir) + object_saver.save(utils_impl.get_variables_path(export_dir)) + + signatures = _canonicalize_signatures(signatures) + graph_def, signatures, saver_def = _make_graph_def( + obj, signatures, object_saver) + saved_model = saved_model_pb2.SavedModel() + saved_model.saved_model_schema_version = ( + constants.SAVED_MODEL_SCHEMA_VERSION) + meta_graph_def = saved_model.meta_graphs.add() + meta_graph_def.saver_def.CopyFrom(saver_def) + # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x + # compatible (no sessions) and share it with this export API rather than + # making a SavedModel proto and writing it directly. + meta_graph_def.graph_def.MergeFrom(graph_def) + for signature_key, signature in signatures.items(): + meta_graph_def.signature_def[signature_key].MergeFrom(signature) + meta_graph.strip_graph_default_valued_attrs(meta_graph_def) + path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) + file_io.write_string_to_file(path, saved_model.SerializeToString()) diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py new file mode 100644 index 00000000000000..42ff508b38ae2b --- /dev/null +++ b/tensorflow/python/saved_model/save_test.py @@ -0,0 +1,276 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for checkpointable object SavedModel save.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from tensorflow.python.eager import backprop +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import save +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import adam +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util + + +class _ModelWithOptimizer(training.Model): + + def __init__(self): + super(_ModelWithOptimizer, self).__init__() + self.dense = core.Dense(1) + self.optimizer = adam.AdamOptimizer(0.01) + + @def_function.function( + input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32), + tensor_spec.TensorSpec([None], dtypes.float32))) + def call(self, x, y): + with backprop.GradientTape() as tape: + loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.) + trainable_variables = self.trainable_variables + gradients = tape.gradient(loss, trainable_variables) + self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + return {"loss": loss} + + +class SaveTest(test.TestCase): + + def _import_and_infer( + self, save_dir, inputs, + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY): + """Import a SavedModel into a TF 1.x-style graph and run `signature_key`.""" + graph = ops.Graph() + with graph.as_default(), self.session(graph) as session: + model = loader.load(session, [], save_dir) + signature = model.signature_def[signature_key] + self.assertEqual(set(inputs.keys()), set(signature.inputs.keys())) + feed_dict = {} + for arg_name in inputs.keys(): + feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = ( + inputs[arg_name]) + output_dict = {} + for output_name, output_tensor_info in signature.outputs.items(): + output_dict[output_name] = graph.get_tensor_by_name( + output_tensor_info.name) + return session.run(output_dict, feed_dict=feed_dict) + + def test_method_save_signature(self): + root = tracking.Checkpointable() + root.f = def_function.function( + lambda x: 2. * x, + input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + root.f(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(root, save_dir, root.f) + self.assertEqual( + {"output_0": 2.}, + self._import_and_infer(save_dir, {"x": 1.})) + + def test_method_save_concrete(self): + root = tracking.Checkpointable() + root.f = def_function.function( + lambda z: {"out": 2. * z}) + root.f(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save( + root, + save_dir, + {"non_default_key": root.f.get_concrete_function( + tensor_spec.TensorSpec(None, dtypes.float32))}) + self.assertEqual( + {"out": 2.}, + self._import_and_infer( + save_dir, {"z": 1.}, signature_key="non_default_key")) + + def test_non_concrete_error(self): + root = tracking.Checkpointable() + root.f = def_function.function(lambda x: 2. * x) + root.f(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertRaisesRegexp( + ValueError, "must be converted to concrete functions"): + save.save(root, save_dir, root.f) + + def test_nested_inputs(self): + root = tracking.Checkpointable() + root.f = def_function.function( + lambda x: 2. * x[0], + input_signature=([tensor_spec.TensorSpec(None, dtypes.float32), + tensor_spec.TensorSpec(None, dtypes.float32)],)) + root.f([constant_op.constant(1.), constant_op.constant(1.)]) + # Concrete functions must always have uniquely named Tensor inputs. Save + # relies on this. + with self.assertRaisesRegexp( + ValueError, "two arguments named 'x'"): + root.f.get_concrete_function() + + def test_nested_outputs(self): + root = tracking.Checkpointable() + root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x))) + root.f(constant_op.constant(1.)) + to_save = root.f.get_concrete_function(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertRaisesRegexp( + ValueError, "non-flat outputs"): + save.save(root, save_dir, to_save) + + def test_nested_dict_outputs(self): + root = tracking.Checkpointable() + root.f = def_function.function( + lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)}) + root.f(constant_op.constant(1.)) + to_save = root.f.get_concrete_function(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertRaisesRegexp( + ValueError, "dictionary containing non-Tensor value"): + save.save(root, save_dir, to_save) + + def test_variable(self): + root = tracking.Checkpointable() + root.v1 = variables.Variable(3.) + root.v2 = variables.Variable(2.) + root.f = def_function.function( + lambda x: root.v1 * root.v2 * x) + root.f(constant_op.constant(1.)) + to_save = root.f.get_concrete_function(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(root, save_dir, to_save) + self.assertAllEqual({"output_0": 12.}, + self._import_and_infer(save_dir, {"x": 2.})) + + def test_optimizer(self): + x = constant_op.constant([[3., 4.]]) + y = constant_op.constant([2.]) + model = _ModelWithOptimizer() + first_loss = model(x, y) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(model, save_dir, model.call) + second_loss = model(x, y) + self.assertNotEqual(first_loss, second_loss) + self.assertAllClose( + second_loss, + self._import_and_infer(save_dir, {"x": [[3., 4.]], "y": [2.]})) + + def test_trivial_save_exception(self): + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertRaisesRegexp(ValueError, "signature"): + save.save(tracking.Checkpointable(), save_dir) + + def test_single_method_default_signature(self): + model = _ModelWithOptimizer() + x = constant_op.constant([[3., 4.]]) + y = constant_op.constant([2.]) + model(x, y) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(model, save_dir) + self.assertIn("loss", + self._import_and_infer(save_dir, + {"x": [[3., 4.]], "y": [2.]})) + + def test_single_function_default_signature(self): + model = tracking.Checkpointable() + model.f = def_function.function(lambda: 3., input_signature=()) + model.f() + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(model, save_dir) + self.assertAllClose({"output_0": 3.}, + self._import_and_infer(save_dir, {})) + + def test_ambiguous_signatures(self): + model = _ModelWithOptimizer() + x = constant_op.constant([[3., 4.]]) + y = constant_op.constant([2.]) + model(x, y) + model.second_function = def_function.function(lambda: 1.) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertRaisesRegexp(ValueError, "call.*second_function"): + save.save(model, save_dir) + + def test_docstring(self): + + class Adder(util.Checkpoint): + + @def_function.function(input_signature=[tensor_spec.TensorSpec( + shape=None, dtype=dtypes.float32)]) + def add(self, x): + return x + x + 1. + + to_save = Adder() + to_save.add(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(to_save, save_dir) + self.assertAllClose({"output_0": 7.}, + self._import_and_infer(save_dir, {"x": 3.})) + + def test_default_attr_stripping(self): + + class Complex(util.Checkpoint): + + @def_function.function(input_signature=[]) + def __call__(self): + return math_ops.complex( + constant_op.constant(1.), + constant_op.constant(2.), + name="complex") + + to_save = Complex() + to_save() + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(to_save, save_dir) + graph = ops.Graph() + with graph.as_default(), self.session(graph) as session: + loader.load(session, [], save_dir) + func, = graph._functions.values() + complex_node, = [ + node for node in func.definition.node_def if node.op == "Complex"] + self.assertNotIn("T", complex_node.attr) + self.assertNotIn("Tout", complex_node.attr) + + +class MemoryTests(test.TestCase): + + def setUp(self): + self._model = _ModelWithOptimizer() + + @test_util.assert_no_garbage_created + def test_no_reference_cycles(self): + x = constant_op.constant([[3., 4.]]) + y = constant_op.constant([2.]) + self._model(x, y) + if sys.version_info[0] < 3: + # TODO(allenl): debug reference cycles in Python 2.x + self.skipTest("This test only works in Python 3+. Reference cycles are " + "created in older Python versions.") + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(self._model, save_dir, self._model.call) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py index 6702c996071364..fcde6b47e4ff10 100644 --- a/tensorflow/python/saved_model/saved_model.py +++ b/tensorflow/python/saved_model/saved_model.py @@ -29,8 +29,8 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils +from tensorflow.python.saved_model.save import save # pylint: enable=unused-import # pylint: disable=wildcard-import from tensorflow.python.saved_model.simple_save import * # pylint: enable=wildcard-import - diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 80b75b7ee65031..5d6167ab38f5a0 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -1420,8 +1420,11 @@ def testInconsistentConsumerDefaultAttrs(self): sess = session.Session(graph=ops.Graph()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - ".*No OpKernel was registered to support Op \'TestAttr\' with these " - "attrs..*"): + "No OpKernel was registered to support Op 'TestAttr' used by node " + "test_attr \\(defined at .*\\) with these attrs: \\[.*\\]\n" + "Registered devices:.*\n" + "Registered kernels:.*" + ): loader.load(sess, ["foo"], export_dir) diff --git a/tensorflow/python/saved_model/simple_save.py b/tensorflow/python/saved_model/simple_save.py index 76d6f666f6e7e8..169504ec891315 100644 --- a/tensorflow/python/saved_model/simple_save.py +++ b/tensorflow/python/saved_model/simple_save.py @@ -23,10 +23,15 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export -@tf_export('saved_model.simple_save') +@tf_export(v1=['saved_model.simple_save']) +@deprecation.deprecated( + None, + 'This function will only be available through the v1 compatibility ' + 'library as tf.compat.v1.saved_model.simple_save.') def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None): """Convenience function to build a SavedModel suitable for serving. diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index b3c27dbd818093..2ee4d9f4e042f2 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -34,10 +34,13 @@ # TensorInfo helpers. -@tf_export( - "saved_model.build_tensor_info", - v1=["saved_model.build_tensor_info", "saved_model.utils.build_tensor_info"]) -@deprecation.deprecated_endpoints("saved_model.utils.build_tensor_info") +@tf_export(v1=["saved_model.build_tensor_info", + "saved_model.utils.build_tensor_info"]) +@deprecation.deprecated( + None, + "This function will only be available through the v1 compatibility " + "library as tf.compat.v1.saved_model.utils.build_tensor_info or " + "tf.compat.v1.saved_model.build_tensor_info.") def build_tensor_info(tensor): """Utility function to build TensorInfo proto. @@ -61,14 +64,13 @@ def build_tensor_info(tensor): return tensor_info -@tf_export( - "saved_model.get_tensor_from_tensor_info", - v1=[ - "saved_model.get_tensor_from_tensor_info", - "saved_model.utils.get_tensor_from_tensor_info" - ]) -@deprecation.deprecated_endpoints( - "saved_model.utils.get_tensor_from_tensor_info") +@tf_export(v1=["saved_model.get_tensor_from_tensor_info", + "saved_model.utils.get_tensor_from_tensor_info"]) +@deprecation.deprecated( + None, + "This function will only be available through the v1 compatibility " + "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or " + "tf.compat.v1.saved_model.get_tensor_from_tensor_info.") def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): """Returns the Tensor or SparseTensor described by a TensorInfo proto. diff --git a/tensorflow/python/summary/README.md b/tensorflow/python/summary/README.md index 8a5fea0d9a130a..ab6e89e5c95e4d 100644 --- a/tensorflow/python/summary/README.md +++ b/tensorflow/python/summary/README.md @@ -8,8 +8,3 @@ events files. If you wish to load TensorFlow events, you should use an EventAccumulator (to load from a single events file) or an EventMultiplexer (to load from multiple events files). - -The API around these tools has not solidified, and we may make backwards- -incompatible changes without warning. - -If you have questions or requests, please contact danmane@google.com diff --git a/tensorflow/python/summary/plugin_asset.py b/tensorflow/python/summary/plugin_asset.py index 998fb30fa491bd..82d3a618304fb9 100644 --- a/tensorflow/python/summary/plugin_asset.py +++ b/tensorflow/python/summary/plugin_asset.py @@ -32,6 +32,8 @@ import abc +import six + from tensorflow.python.framework import ops _PLUGIN_ASSET_PREFIX = "__tensorboard_plugin_asset__" @@ -107,6 +109,7 @@ def get_all_plugin_assets(graph=None): return out +@six.add_metaclass(abc.ABCMeta) class PluginAsset(object): """This abstract base class allows TensorBoard to serialize assets to disk. @@ -124,7 +127,6 @@ class PluginAsset(object): writer calls assets and the PluginAsset instance provides its contents to be written to disk. """ - __metaclass__ = abc.ABCMeta plugin_name = None diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 16b8626476eb1d..78217b503ffac9 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -20,6 +20,7 @@ import os.path import time +import warnings from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import summary_pb2 @@ -364,6 +365,8 @@ def __init__(self, else: event_writer = EventFileWriter(logdir, max_queue, flush_secs, filename_suffix) + + self._closed = False super(FileWriter, self).__init__(event_writer, graph, graph_def) def __enter__(self): @@ -378,12 +381,23 @@ def get_logdir(self): """Returns the directory where event file will be written.""" return self.event_writer.get_logdir() + def _warn_if_event_writer_is_closed(self): + if self._closed: + warnings.warn("Attempting to use a closed FileWriter. " + "The operation will be a noop unless the FileWriter " + "is explicitly reopened.") + + def _add_event(self, event, step): + self._warn_if_event_writer_is_closed() + super(FileWriter, self)._add_event(event, step) + def add_event(self, event): """Adds an event to the event file. Args: event: An `Event` protocol buffer. """ + self._warn_if_event_writer_is_closed() self.event_writer.add_event(event) def flush(self): @@ -392,6 +406,9 @@ def flush(self): Call this method to make sure that all pending events have been written to disk. """ + # Flushing a closed EventFileWriterV2 raises an exception. It is, + # however, a noop for EventFileWriter. + self._warn_if_event_writer_is_closed() self.event_writer.flush() def close(self): @@ -400,6 +417,7 @@ def close(self): Call this method when you do not need the summary writer anymore. """ self.event_writer.close() + self._closed = True def reopen(self): """Reopens the EventFileWriter. @@ -410,3 +428,4 @@ def reopen(self): Does nothing if the EventFileWriter was not closed. """ self.event_writer.reopen() + self._closed = False diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py index 670230e917eb33..09d4b63fbb6178 100644 --- a/tensorflow/python/summary/writer/writer_test.py +++ b/tensorflow/python/summary/writer/writer_test.py @@ -22,6 +22,7 @@ import os.path import shutil import time +import warnings from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import summary_pb2 @@ -273,6 +274,21 @@ def testNonBlockingClose(self): sw.close() self._assertRecent(time_before_close) + def testUseAfterClose(self): + test_dir = self._CleanTestDir("use_after_close") + sw = self._FileWriter(test_dir) + sw.close() + with warnings.catch_warnings(record=True) as triggered: + warnings.simplefilter("always") + self.assertFalse(triggered) + sw.add_summary(summary_pb2.Summary()) + sw.add_session_log(event_pb2.SessionLog()) + sw.add_graph(ops.Graph()) + + self.assertEqual(len(triggered), 3) + for w in triggered: + self.assertEqual(w.category, UserWarning) + def testWithStatement(self): test_dir = self._CleanTestDir("with_statement") with self._FileWriter(test_dir) as sw: diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl index c4633dfe32efe8..2e5d875a58ae4a 100644 --- a/tensorflow/python/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -2,16 +2,6 @@ load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES") -# keep sorted -ESTIMATOR_API_INIT_FILES = [ - # BEGIN GENERATED ESTIMATOR FILES - "__init__.py", - "estimator/__init__.py", - "estimator/export/__init__.py", - "estimator/inputs/__init__.py", - # END GENERATED ESTIMATOR FILES -] - def get_compat_files( file_paths, compat_api_version): @@ -27,7 +17,7 @@ def gen_api_init_files( api_version = 2, compat_api_versions = [], compat_init_templates = [], - packages = ["tensorflow.python"], + packages = ["tensorflow.python", "tensorflow.lite.python.lite"], package_deps = ["//tensorflow/python:no_contrib"], output_package = "tensorflow", output_dir = ""): diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index af6d53aaaaa3b6..ac7bc28b2be130 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -13,6 +13,7 @@ TENSORFLOW_API_INIT_FILES = [ "distributions/__init__.py", "dtypes/__init__.py", "errors/__init__.py", + "experimental/__init__.py", "feature_column/__init__.py", "gfile/__init__.py", "graph_util/__init__.py", @@ -60,9 +61,10 @@ TENSORFLOW_API_INIT_FILES = [ "keras/wrappers/__init__.py", "keras/wrappers/scikit_learn/__init__.py", "linalg/__init__.py", + "lite/__init__.py", + "lite/constants/__init__.py", "logging/__init__.py", "losses/__init__.py", - "manip/__init__.py", "math/__init__.py", "metrics/__init__.py", "nn/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 8d3b84751775a2..d0bac4033ca25c 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -13,6 +13,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "distributions/__init__.py", "dtypes/__init__.py", "errors/__init__.py", + "experimental/__init__.py", "feature_column/__init__.py", "gfile/__init__.py", "graph_util/__init__.py", @@ -62,6 +63,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "layers/__init__.py", "layers/experimental/__init__.py", "linalg/__init__.py", + "lite/__init__.py", + "lite/constants/__init__.py", "logging/__init__.py", "losses/__init__.py", "manip/__init__.py", @@ -78,6 +81,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "saved_model/__init__.py", "saved_model/builder/__init__.py", "saved_model/constants/__init__.py", + "saved_model/experimental/__init__.py", "saved_model/loader/__init__.py", "saved_model/main_op/__init__.py", "saved_model/signature_constants/__init__.py", diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 3e4186f7afb6a0..f6258034213693 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -299,7 +299,8 @@ def in_packages(m): module.__name__ is None or not in_packages(module.__name__)): continue # Do not generate __init__.py files for contrib modules for now. - if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): + if (('.contrib.' in module.__name__ or module.__name__.endswith('.contrib')) + and '.lite' not in module.__name__): continue for module_contents_name in dir(module): diff --git a/tensorflow/python/tools/api/generator/output_init_files_test.py b/tensorflow/python/tools/api/generator/output_init_files_test.py index 1b6556c59c299f..ab154af9101e32 100644 --- a/tensorflow/python/tools/api/generator/output_init_files_test.py +++ b/tensorflow/python/tools/api/generator/output_init_files_test.py @@ -23,6 +23,7 @@ # available in sys.modules # pylint: disable=unused-import from tensorflow import python as _tf_for_api_traversal +from tensorflow.lite.python import lite as _tflite_for_api_traversal # pylint: enable=unused-import from tensorflow.python.platform import test from tensorflow.python.util import tf_decorator diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 57954ec56a5943..857da431db2cb9 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -318,13 +318,13 @@ def _set_checkpoint_initializer(variable, saveable_objects.append(s) assert len(saveable_objects) == 1 # Should be only one variable. - init_op = saveable_objects[0].restore([restore_op], restored_shapes=None) + init_op = saveable_objects[0].restore([restore_op], restored_shapes=None) - # pylint:disable=protected-access - variable._initializer_op = init_op - restore_op.set_shape(variable.shape) - variable._initial_value = restore_op - # pylint:enable=protected-access + # pylint:disable=protected-access + variable._initializer_op = init_op + restore_op.set_shape(variable.shape) + variable._initial_value = restore_op + # pylint:enable=protected-access def _set_variable_or_list_initializer(variable_or_list, ckpt_file, diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py index 61dcbdb2b8f922..a3e58de4a31bca 100644 --- a/tensorflow/python/training/checkpoint_utils_test.py +++ b/tensorflow/python/training/checkpoint_utils_test.py @@ -207,9 +207,6 @@ def testRestoreRunsOnSameDevice(self): checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"useful_scope/": "useful_scope/"}) - # initializer runs on the same task but always on CPU. - self.assertEqual(my4._initializer_op.op.inputs[1].device, - "/job:ps/device:CPU:0") def testInitFromRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py index 558ae0855e4c46..c85b208d479855 100644 --- a/tensorflow/python/training/checkpointable/tracking.py +++ b/tensorflow/python/training/checkpointable/tracking.py @@ -19,6 +19,11 @@ from tensorflow.python.training.checkpointable import base from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.util import tf_contextlib + + +# global _RESOURCE_TRACKER_STACK +_RESOURCE_TRACKER_STACK = [] class NotCheckpointable(object): @@ -72,10 +77,57 @@ def _no_dependency(self, value): return data_structures.NoDependency(value) +class ResourceTracker(object): + """An object that tracks a list of resources.""" + + def __init__(self): + self._resources = [] + + @property + def resources(self): + return self._resources + + def add_resource(self, resource): + self._resources.append(resource) + + +@tf_contextlib.contextmanager +def resource_tracker_scope(resource_tracker): + """A context to manage resource trackers. + + Use this in order to collect up all resources created within a block of code. + Example usage: + + ```python + resource_tracker = ResourceTracker() + with resource_tracker_scope(resource_tracker): + resource = TrackableResource() + + assert resource_tracker.resources == [resource] + + Args: + resource_tracker: The passed in ResourceTracker object + + Yields: + A scope in which the resource_tracker is active. + """ + global _RESOURCE_TRACKER_STACK + old = list(_RESOURCE_TRACKER_STACK) + _RESOURCE_TRACKER_STACK.append(resource_tracker) + try: + yield + finally: + _RESOURCE_TRACKER_STACK = old + + class TrackableResource(base.CheckpointableBase): """Base class for all resources that need to be tracked.""" def __init__(self): + global _RESOURCE_TRACKER_STACK + for resource_tracker in _RESOURCE_TRACKER_STACK: + resource_tracker.add_resource(self) + self._resource_handle = None def create_resource(self): diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py index a44c570fb9fe41..17c5461bc25e5e 100644 --- a/tensorflow/python/training/checkpointable/tracking_test.py +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -193,5 +193,62 @@ def testAssertions(self): self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]}, self.evaluate(a.tensors)) + +class _DummyResource(tracking.TrackableResource): + + def __init__(self, handle_name): + self._handle_name = handle_name + super(_DummyResource, self).__init__() + + def create_resource(self): + return self._handle_name + + +class ResourceTrackerTest(test.TestCase): + + def testBasic(self): + resource_tracker = tracking.ResourceTracker() + with tracking.resource_tracker_scope(resource_tracker): + dummy_resource1 = _DummyResource("test1") + dummy_resource2 = _DummyResource("test2") + + self.assertEqual(2, len(resource_tracker.resources)) + self.assertEqual("test1", resource_tracker.resources[0].resource_handle) + self.assertEqual("test2", resource_tracker.resources[1].resource_handle) + + def testTwoScopes(self): + resource_tracker1 = tracking.ResourceTracker() + with tracking.resource_tracker_scope(resource_tracker1): + dummy_resource1 = _DummyResource("test1") + + resource_tracker2 = tracking.ResourceTracker() + with tracking.resource_tracker_scope(resource_tracker2): + dummy_resource2 = _DummyResource("test2") + + self.assertEqual(1, len(resource_tracker1.resources)) + self.assertEqual("test1", resource_tracker1.resources[0].resource_handle) + self.assertEqual(1, len(resource_tracker1.resources)) + self.assertEqual("test2", resource_tracker2.resources[0].resource_handle) + + def testNestedScopesScopes(self): + resource_tracker = tracking.ResourceTracker() + with tracking.resource_tracker_scope(resource_tracker): + resource_tracker1 = tracking.ResourceTracker() + with tracking.resource_tracker_scope(resource_tracker1): + dummy_resource1 = _DummyResource("test1") + + resource_tracker2 = tracking.ResourceTracker() + with tracking.resource_tracker_scope(resource_tracker2): + dummy_resource2 = _DummyResource("test2") + + self.assertEqual(1, len(resource_tracker1.resources)) + self.assertEqual("test1", resource_tracker1.resources[0].resource_handle) + self.assertEqual(1, len(resource_tracker1.resources)) + self.assertEqual("test2", resource_tracker2.resources[0].resource_handle) + self.assertEqual(2, len(resource_tracker.resources)) + self.assertEqual("test1", resource_tracker.resources[0].resource_handle) + self.assertEqual("test2", resource_tracker.resources[1].resource_handle) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index edab6cc6ebb8da..f45f7445f13705 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -551,7 +551,7 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): def _serialize_checkpointables( checkpointable_objects, node_ids, object_names, slot_variables, - saveables_cache): + saveables_cache, object_map): """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) @@ -569,12 +569,17 @@ def _serialize_checkpointables( object_proto = object_graph_proto.nodes.add() object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) object_name = object_names[checkpointable] + if object_map: + object_to_save = object_map.get(checkpointable, checkpointable) + else: + object_to_save = checkpointable if saveables_cache is not None: - cached_attributes = saveables_cache.setdefault(checkpointable, {}) + cached_attributes = saveables_cache.setdefault(object_to_save, {}) else: cached_attributes = None + for name, saveable_factory in ( - checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access + object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( @@ -650,6 +655,28 @@ def _serialize_checkpointables( return named_saveables, object_graph_proto, feed_additions +def _serialize_gathered_objects( + checkpointable_objects, path_to_root, saveables_cache, object_map): + """Create SaveableObjects and protos for gathered objects.""" + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id + slot_variables = _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return _serialize_checkpointables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names, + slot_variables=slot_variables, + saveables_cache=saveables_cache, + object_map=object_map) + + def _serialize_object_graph(root_checkpointable, saveables_cache): """Determine checkpoint keys for variables and build a serialized graph. @@ -680,22 +707,8 @@ def _serialize_object_graph(root_checkpointable, saveables_cache): """ checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = _ObjectIdentityDictionary() - for obj, path in path_to_root.items(): - object_names[obj] = _object_prefix_from_path(path) - node_ids = _ObjectIdentityDictionary() - for node_id, node in enumerate(checkpointable_objects): - node_ids[node] = node_id - slot_variables = _serialize_slot_variables( - checkpointable_objects=checkpointable_objects, - node_ids=node_ids, - object_names=object_names) - return _serialize_checkpointables( - checkpointable_objects=checkpointable_objects, - node_ids=node_ids, - object_names=object_names, - slot_variables=slot_variables, - saveables_cache=saveables_cache) + return _serialize_gathered_objects( + checkpointable_objects, path_to_root, saveables_cache, object_map=None) def named_saveables(root_checkpointable): @@ -808,7 +821,7 @@ def _checkpointable_custom_creator(next_creator, name, initial_value, """ def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which - # we don't want to propagate. + # we don't want to propagate. return next_creator( initial_value=initializer, name=name, @@ -969,6 +982,12 @@ def assert_existing_objects_matched(self): raise AssertionError( "Object not assigned a value from checkpoint: %s" % (node,)) for checkpointable_object in list_objects(self._root_checkpointable): + # Remove data structures that do not contain any variables from + # restoration checks. + if (isinstance(checkpointable_object, + data_structures.CheckpointableDataStructure) and + not checkpointable_object._checkpoint_dependencies): + continue self._checkpoint.all_python_objects.add(checkpointable_object) unused_python_objects = ( _ObjectIdentitySet(self._checkpoint.all_python_objects) @@ -1306,12 +1325,30 @@ def _gather_saveables( name=base.OBJECT_GRAPH_PROTO_KEY)) return named_saveable_objects, graph_proto, feed_additions - def freeze(self): + def freeze(self, object_map=None, to_graph=None): """Creates a `tf.train.Saver` with the current object graph frozen.""" - named_saveable_objects, _, _ = self._gather_saveables( - object_graph_tensor=None, saveable_object_cache=None) - return saver_lib.Saver( - var_list=named_saveable_objects, max_to_keep=None) + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(self._root_checkpointable)) + if to_graph: + target_context = to_graph.as_default + else: + target_context = ops.NullContextmanager + with target_context(): + named_saveable_objects, graph_proto, _ = _serialize_gathered_objects( + checkpointable_objects, + path_to_root, + saveables_cache=None, + object_map=object_map) + with ops.device("/cpu:0"): + object_graph_tensor = constant_op.constant( + graph_proto.SerializeToString(), dtype=dtypes.string) + named_saveable_objects.append( + base.NoRestoreSaveable( + tensor=object_graph_tensor, + name=base.OBJECT_GRAPH_PROTO_KEY)) + # TODO(allenl, haoliang): Swap in a function-based saver here. + return saver_lib.Saver( + var_list=named_saveable_objects, max_to_keep=None) def _prepare_save(self, object_graph_tensor=None, diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 24fd42f6d2e466..19955140123afc 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -1313,6 +1313,24 @@ def test_initialize_if_not_restoring(self): train_fn() self.assertEqual(42., self.evaluate(optimizer.variables()[0])) + @test_util.run_in_graph_and_eager_modes + def test_restore_after_adding_empty_checkpointable_data_structure(self): + model = NonLayerCheckpointable() + checkpoint = checkpointable_utils.Checkpoint(model=model) + checkpoint.restore(None).initialize_or_restore() + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + + del model, checkpoint + + model = NonLayerCheckpointable() + model.dict = {"a": 1} + model.list = {"b": 1} + checkpoint = checkpointable_utils.Checkpoint(model=model) + load_status = checkpoint.restore(save_path) + load_status.assert_existing_objects_matched().run_restore_ops() + class _ManualScope(tracking.Checkpointable): diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index e12ebafba1f74a..35ed52fa1293b1 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -1254,6 +1254,10 @@ def value_container(self, value): def num_replicas(self): return 1 + @property + def num_replicas_in_sync(self): + return 1 + @property def worker_devices(self): raise RuntimeError( diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index a9b05dcc736e7c..085b77d1d6aee7 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -1035,8 +1035,11 @@ def _testTwoThreadsHelper(self, use_dict): self.assertAllEqual([99] * len(which_b), [results[0][i] for i in which_b]) - # Some minimum level of mixing of the results of both threads. - self.assertGreater(saw_both, 1) + # We'd like to see some minimum level of mixing of the results of both + # threads, but we can't rely on fair thread scheduling, so we just log. + # self.assertGreater(saw_both, 1) + tf_logging.info("testTwoThreads%s saw both count: %s", + "Dict" if use_dict else "", saw_both) # Verify the order of results from "a" were preserved. self.assertAllEqual(all_a, np.arange(num_a)) @@ -1048,10 +1051,10 @@ def _testTwoThreadsHelper(self, use_dict): for thread in threads: thread.join() - def DISABLED_testTwoThreads(self): + def testTwoThreads(self): self._testTwoThreadsHelper(use_dict=False) - def DISABLED_testTwoThreadsDict(self): + def testTwoThreadsDict(self): self._testTwoThreadsHelper(use_dict=True) def testMismatchedDictKeys(self): @@ -1068,7 +1071,7 @@ def testMismatchedDictKeys(self): }], batch_size=8) - def DISABLED_testTwoThreadsDynamicPad(self): + def testTwoThreadsDynamicPad(self): with self.cached_session() as sess: # Two threads, the first generates (0..69, ["a"] * 1..70). num_a = 70 @@ -1128,8 +1131,10 @@ def DISABLED_testTwoThreadsDynamicPad(self): self.assertAllEqual([99] * len(which_b), [results[0][i] for i in which_b]) - # Some minimum level of mixing of the results of both threads. - self.assertGreater(saw_both, 1) + # We'd like to see some minimum level of mixing of the results of both + # threads, but we can't rely on fair thread scheduling, so we just log. + # self.assertGreater(saw_both, 1) + tf_logging.info("testTwoThreadsDynamicPad saw both count: %s", saw_both) # Verify the order of results from "a" were preserved. self.assertAllEqual( # tiled "a" with counter + 1 @@ -1143,7 +1148,7 @@ def DISABLED_testTwoThreadsDynamicPad(self): for thread in threads: thread.join() - def DISABLED_testTwoThreadsSmallerBatch(self): + def testTwoThreadsSmallerBatch(self): with self.cached_session() as sess: extra_elements = 2 # Two threads, the first generates (0..69, "a"). @@ -1229,8 +1234,10 @@ def DISABLED_testTwoThreadsSmallerBatch(self): all_a.extend([results[0][i] for i in which_a]) seen_b += len(which_b) - # Some minimum level of mixing of the results of both threads. - self.assertGreater(saw_both, 1) + # We'd like to see some minimum level of mixing of the results of both + # threads, but we can't rely on fair thread scheduling, so we just log. + # self.assertGreater(saw_both, 1) + tf_logging.info("testTwoThreadsSmallerBatch saw both count: %s", saw_both) # Verify the order of results from "a" were preserved. self.assertAllEqual(all_a, np.arange(num_a)) @@ -1242,7 +1249,7 @@ def DISABLED_testTwoThreadsSmallerBatch(self): for thread in threads: thread.join() - def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self): + def testTwoThreadsDynamicPadSmallerBatch(self): with self.cached_session() as sess: extra_elements = 2 # Two threads, the first generates (0..69, ["a"] * 1..70). @@ -1322,8 +1329,11 @@ def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self): all_a.extend([results[0][i] for i in which_a]) seen_b += len(which_b) - # Some minimum level of mixing of the results of both threads. - self.assertGreater(saw_both, 1) + # We'd like to see some minimum level of mixing of the results of both + # threads, but we can't rely on fair thread scheduling, so we just log. + # self.assertGreater(saw_both, 1) + tf_logging.info("testTwoThreadsDynamicPadSmallerBatch saw both count: %s", + saw_both) # Verify the order of results from "a" were preserved. self.assertAllEqual( # tiled "a" with counter + 1 diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index a479f38165ef25..0687eb5d4bc6cb 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -509,6 +509,7 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name @tf_export('train.SessionCreator') +@six.add_metaclass(abc.ABCMeta) class SessionCreator(object): """A factory for tf.Session.""" @@ -1071,8 +1072,10 @@ def close(self): if self._sess: try: self._sess.close() - except _PREEMPTION_ERRORS: - pass + except _PREEMPTION_ERRORS as e: + logging.warning('An error occurred when attempting to close the ' + 'session. This may be due to a preemption in a ' + 'connected worker or parameter server. Error: %s', e) finally: self._sess = None diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 79906317b4b6ac..8e400f2aebaeb6 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -22,6 +22,8 @@ import abc +import six + from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -84,6 +86,7 @@ def _var_key(var): return var._unique_id # pylint: disable=protected-access +@six.add_metaclass(abc.ABCMeta) class _OptimizableVariable(object): """Interface for abstracting over variables in the optimizers.""" diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py index 6a3756fba9fd97..fbde8fe3c2a5ee 100644 --- a/tensorflow/python/training/sync_replicas_optimizer.py +++ b/tensorflow/python/training/sync_replicas_optimizer.py @@ -31,6 +31,7 @@ from tensorflow.python.training import queue_runner from tensorflow.python.training import session_manager from tensorflow.python.training import session_run_hook +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -39,7 +40,7 @@ # rate according to the number of replicas. This change is introduced to be # consistent with how gradients are aggregated (averaged) within a batch in a # replica. -@tf_export("train.SyncReplicasOptimizer") +@tf_export(v1=["train.SyncReplicasOptimizer"]) class SyncReplicasOptimizer(optimizer.Optimizer): """Class to synchronize, aggregate gradients and pass them to the optimizer. @@ -139,6 +140,12 @@ class SyncReplicasOptimizer(optimizer.Optimizer): ``` """ + @deprecation.deprecated( + None, + "The `SyncReplicaOptimizer` is deprecated. For synchrononous training, " + "please use [Distribution Strategies](https://github.com/tensorflow/" + "tensorflow/tree/master/tensorflow/contrib/distribute).", + warn_once=True) def __init__(self, opt, replicas_to_aggregate, diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index a1870dd9de74c4..a7a07babfe12b3 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -217,5 +217,4 @@ def export_constant(self, module_name, name): tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) -estimator_export = functools.partial( - api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True) +estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME) diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 42b3fde5b0816f..0fb05089d7530a 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/host_or_device_scalar.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index 90449b5d5d7fb6..6af71b6c9d1941 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -39,14 +39,13 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "tensorflow/stream_executor/lib/error.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/numbers.h" #include "tensorflow/stream_executor/lib/process_state.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/str_util.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/platform/logging.h" @@ -363,7 +362,7 @@ port::StatusOr Diagnostician::FindKernelDriverVersion() { } static const int kContentsSize = 1024; - port::InlinedVector contents(kContentsSize); + absl::InlinedVector contents(kContentsSize); size_t retcode = fread(contents.begin(), 1, kContentsSize - 2, driver_version_file); if (retcode < kContentsSize - 1) { diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index e5baf779debd9b..19397c7dbf21c3 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/mathutil.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -46,6 +45,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_pimpl.h" // clang-format off #include "cuda/include/cudnn.h" +#include "absl/strings/string_view.h" // clang-format on namespace stream_executor { @@ -313,20 +313,18 @@ port::StatusOr GetCudnnProperty(libraryPropertyType type) { return value; } -cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) { - if (algorithm.is_default()) { +cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional algorithm) { + if (!algorithm.has_value()) { return CUDNN_RNN_ALGO_STANDARD; - } else { - cudnnRNNAlgo_t algo = static_cast(algorithm.algo_id()); - switch (algo) { - case CUDNN_RNN_ALGO_STANDARD: - case CUDNN_RNN_ALGO_PERSIST_STATIC: - case CUDNN_RNN_ALGO_PERSIST_DYNAMIC: - return algo; - default: - LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " - << algorithm.algo_id(); - } + } + cudnnRNNAlgo_t algo = static_cast(algorithm->algo_id()); + switch (algo) { + case CUDNN_RNN_ALGO_STANDARD: + case CUDNN_RNN_ALGO_PERSIST_STATIC: + case CUDNN_RNN_ALGO_PERSIST_DYNAMIC: + return algo; + default: + LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id(); } } @@ -1072,10 +1070,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { // in profile mode, which is run with algorithms returned from // GetRnnAlgorithms() (which are non-default and explicitly set whether to // use tensor ops). - if (RnnTensorOpMathEnabled() && - !algorithm_config.algorithm().is_default()) { + if (RnnTensorOpMathEnabled() && algorithm_config.algorithm().has_value()) { cudnnMathType_t math_type = - algorithm_config.algorithm().tensor_ops_enabled() + algorithm_config.algorithm()->tensor_ops_enabled() ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); @@ -1513,7 +1510,7 @@ port::Status CudnnSupport::DoRnnForwardImpl( if (!timer->Stop(AsCUDAStream(stream))) { return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - auto algo_desc = rnn_desc.algorithm_config().algorithm(); + auto algo_desc = *rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); output_profile_result->set_elapsed_time_in_ms( timer->GetElapsedMilliseconds()); @@ -1616,7 +1613,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( if (!timer->Stop(AsCUDAStream(stream))) { return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - auto algo_desc = rnn_desc.algorithm_config().algorithm(); + auto algo_desc = *rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); output_profile_result->set_elapsed_time_in_ms( timer->GetElapsedMilliseconds()); @@ -2016,12 +2013,13 @@ port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, + const CudnnTensorDescriptor& output_nd, + const dnn::AlgorithmDesc& algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2029,14 +2027,9 @@ port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( cudnn.handle(), /*xDesc=*/input_nd.handle(), /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(*algorithm_desc), + /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); - if (TF_PREDICT_FALSE(!algorithm_desc)) { - return port::Status(port::error::INVALID_ARGUMENT, - "No AlgorithmDesc provided"); - } - algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2063,12 +2056,13 @@ AllocateCudnnConvolutionBackwardDataWorkspace( Stream* stream, const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, + const CudnnTensorDescriptor& output_nd, + const dnn::AlgorithmDesc& algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2078,14 +2072,9 @@ AllocateCudnnConvolutionBackwardDataWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*dxDesc=*/input_nd.handle(), - /*algo=*/ToConvBackwardDataAlgo(*algorithm_desc), + /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); - if (TF_PREDICT_FALSE(!algorithm_desc)) { - return port::Status(port::error::INVALID_ARGUMENT, - "No AlgorithmDesc provided"); - } - algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2112,12 +2101,13 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( Stream* stream, const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, + const CudnnTensorDescriptor& output_nd, + const dnn::AlgorithmDesc& algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2127,14 +2117,9 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*gradDesc=*/filter.handle(), - /*algo=*/ToConvBackwardFilterAlgo(*algorithm_desc), + /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); - if (TF_PREDICT_FALSE(!algorithm_desc)) { - return port::Status(port::error::INVALID_ARGUMENT, - "No AlgorithmDesc provided"); - } - algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2163,8 +2148,8 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { - dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); - if (algorithm_config.algorithm().is_default()) { + absl::optional algo_desc = algorithm_config.algorithm(); + if (!algo_desc.has_value()) { // Pick fastest algorithm within memory limit according to cuDNN's // heuristics. bool specify_workspace_limit = scratch_allocator != nullptr; @@ -2176,33 +2161,33 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( GetCudnnConvolutionForwardAlgo( cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit, memory_limit_bytes)); - algo_desc = dnn::AlgorithmDesc( - algo, algorithm_config.algorithm().tensor_ops_enabled()); + algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true); } auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, + stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator); if (scratch_or.ok()) { *scratch = scratch_or.ValueOrDie(); - return algo_desc; + return *algo_desc; } + algo_desc = algorithm_config.algorithm_no_scratch(); + // Failed to allocate workspace for the first algorithm, fall back to the // no_scratch algorithm. - if (algorithm_config.algorithm_no_scratch().is_default()) { + if (!algo_desc.has_value()) { return port::Status( port::error::INVALID_ARGUMENT, "The primary convolution algorithm failed memory allocation, " "while a secondary algorithm is not provided."); } - algo_desc = algorithm_config.algorithm_no_scratch(); SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( stream, cudnn, input_nd, filter, conv, - output_nd, &algo_desc, scratch_allocator)); - return algo_desc; + output_nd, *algo_desc, scratch_allocator)); + return *algo_desc; } port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( @@ -2212,8 +2197,8 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { - dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); - if (algorithm_config.algorithm().is_default()) { + absl::optional algo_desc = algorithm_config.algorithm(); + if (!algo_desc.has_value()) { // Pick fastest algorithm within memory limit according to cuDNN's // heuristics. bool specify_workspace_limit = scratch_allocator != nullptr; @@ -2225,33 +2210,33 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( GetCudnnConvolutionBackwardDataAlgo( cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit, memory_limit_bytes)); - algo_desc = dnn::AlgorithmDesc( - algo, algorithm_config.algorithm().tensor_ops_enabled()); + algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true); } auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, + stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator); if (scratch_or.ok()) { *scratch = scratch_or.ValueOrDie(); - return algo_desc; + return *algo_desc; } + algo_desc = algorithm_config.algorithm_no_scratch(); + // Failed to allocate workspace for the first algorithm, fall back to the // no_scratch algorithm. - if (algorithm_config.algorithm_no_scratch().is_default()) { + if (!algo_desc.has_value()) { return port::Status( port::error::INVALID_ARGUMENT, "The primary convolution algorithm failed memory allocation, " "while a secondary algorithm is not provided."); } - algo_desc = algorithm_config.algorithm_no_scratch(); SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( stream, cudnn, input_nd, filter, conv, - output_nd, &algo_desc, scratch_allocator)); - return algo_desc; + output_nd, *algo_desc, scratch_allocator)); + return *algo_desc; } port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( @@ -2261,8 +2246,8 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { - dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); - if (algorithm_config.algorithm().is_default()) { + absl::optional algo_desc = algorithm_config.algorithm(); + if (!algo_desc.has_value()) { // Pick fastest algorithm within memory limit according to cuDNN's // heuristics. bool specify_workspace_limit = scratch_allocator != nullptr; @@ -2274,33 +2259,33 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( GetCudnnConvolutionBackwardFilterAlgo( cudnn, input_nd, filter, conv, output_nd, specify_workspace_limit, memory_limit_bytes)); - algo_desc = dnn::AlgorithmDesc( - algo, algorithm_config.algorithm().tensor_ops_enabled()); + algo_desc = dnn::AlgorithmDesc(algo, /*use_tensor_ops=*/true); } auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, + stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc, scratch_allocator); if (scratch_or.ok()) { *scratch = scratch_or.ValueOrDie(); - return algo_desc; + return *algo_desc; } + algo_desc = algorithm_config.algorithm_no_scratch(); + // Failed to allocate workspace for the first algorithm, fall back to the // no_scratch algorithm. - if (algorithm_config.algorithm_no_scratch().is_default()) { + if (!algo_desc.has_value()) { return port::Status( port::error::INVALID_ARGUMENT, "The primary convolution algorithm failed memory allocation, " "while a secondary algorithm is not provided."); } - algo_desc = algorithm_config.algorithm_no_scratch(); SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( stream, cudnn, input_nd, filter, conv, - output_nd, &algo_desc, scratch_allocator)); - return algo_desc; + output_nd, *algo_desc, scratch_allocator)); + return *algo_desc; } // A helper class to set env-vars and choose options for cudnn-related @@ -2317,7 +2302,7 @@ class CudnnEnvVar { static bool IsEnabledImpl() { const char* tf_env_var_val = getenv(EnvVar::kName); if (tf_env_var_val != nullptr) { - port::StringPiece tf_env_var_val_str(tf_env_var_val); + absl::string_view tf_env_var_val_str(tf_env_var_val); if (tf_env_var_val_str == "0") { return false; } @@ -2545,6 +2530,7 @@ port::Status CudnnSupport::DoConvolveImpl( output_profile_result->set_algorithm(algo_desc); output_profile_result->set_elapsed_time_in_ms( timer->GetElapsedMilliseconds()); + output_profile_result->set_scratch_size(scratch.size()); } return port::Status::OK(); @@ -2661,6 +2647,7 @@ port::Status CudnnSupport::DoFusedConvolveImpl( output_profile_result->set_algorithm(algo_desc); output_profile_result->set_elapsed_time_in_ms( timer->GetElapsedMilliseconds()); + output_profile_result->set_scratch_size(scratch.size()); } return port::Status::OK(); @@ -3177,10 +3164,8 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl( // Cudnn 7.1.4 has a bug if the workspace of the following convolution is not // zero-initialized, nvbugs/2254619. if (CUDNN_VERSION >= 7000 && CUDNN_VERSION < 7300 && - algorithm_config.algorithm().algo_id() == - CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 && - cudnn_type == CUDNN_DATA_HALF && - algorithm_config.algorithm().tensor_ops_enabled() && + algo_desc.algo_id() == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 && + cudnn_type == CUDNN_DATA_HALF && algo_desc.tensor_ops_enabled() && input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX && output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX && @@ -3210,6 +3195,7 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl( output_profile_result->set_algorithm(algo_desc); output_profile_result->set_elapsed_time_in_ms( timer->GetElapsedMilliseconds()); + output_profile_result->set_scratch_size(scratch.size()); } return port::Status::OK(); @@ -3371,8 +3357,7 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl( // // See nvbugs/2379553. if (CUDNN_VERSION >= 7100 && CUDNN_VERSION < 7300 && - algorithm_config.algorithm().algo_id() == - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 && + algo_desc.algo_id() == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 && cudnn_type == CUDNN_DATA_HALF && input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput && @@ -3403,6 +3388,7 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl( output_profile_result->set_algorithm(algo_desc); output_profile_result->set_elapsed_time_in_ms( timer->GetElapsedMilliseconds()); + output_profile_result->set_scratch_size(scratch.size()); } return port::Status::OK(); diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index a674814190425a..b34d1f722eaf60 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -22,12 +22,12 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/human_readable.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/notification.h" #include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/stacktrace.h" @@ -336,7 +336,7 @@ static port::Status InternalInit() { /* static */ bool CUDADriver::GetDeviceName(CUdevice device, string *device_name) { static const size_t kCharLimit = 64; - port::InlinedVector chars(kCharLimit); + absl::InlinedVector chars(kCharLimit); CUresult res = cuDeviceGetName(chars.begin(), kCharLimit - 1, device); if (res != CUDA_SUCCESS) { LOG(ERROR) << "failed to get device name for " << device << ": " @@ -575,8 +575,8 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) { static const unsigned int kLogBufferBytesLimit = 1024; unsigned int error_log_buffer_bytes = kLogBufferBytesLimit; unsigned int info_log_buffer_bytes = kLogBufferBytesLimit; - port::InlinedVector error_log_buffer(error_log_buffer_bytes); - port::InlinedVector info_log_buffer(info_log_buffer_bytes); + absl::InlinedVector error_log_buffer(error_log_buffer_bytes); + absl::InlinedVector info_log_buffer(info_log_buffer_bytes); bool log_verbose = true; CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, @@ -1466,7 +1466,7 @@ static port::StatusOr GetSimpleAttribute(CUdevice device, /* static */ string CUDADriver::GetPCIBusID(CUdevice device) { string pci_bus_id; static const int kBufferSize = 64; - port::InlinedVector chars(kBufferSize); + absl::InlinedVector chars(kBufferSize); chars[kBufferSize - 1] = '\0'; CUresult res = cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); if (res != CUDA_SUCCESS) { diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc index 013ca2d7f6d7f9..cbf388a0f892b0 100644 --- a/tensorflow/stream_executor/cuda/cuda_fft.cc +++ b/tensorflow/stream_executor/cuda/cuda_fft.cc @@ -332,6 +332,7 @@ std::unique_ptr CUDAFft::Create1dPlan(Stream *stream, uint64 num_x, // TODO(yangzihao): In the future, send error msg back to TensorFlow // so it can fail gracefully, if (!status.ok()) { + LOG(ERROR) << "Plan Parameters: num_x: " << num_x; LOG(FATAL) << "failed to initialize cufft 1d plan: " << status.error_message(); } @@ -346,6 +347,7 @@ std::unique_ptr CUDAFft::Create1dPlanWithScratchAllocator( port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, type, scratch_allocator); if (!status.ok()) { + LOG(ERROR) << "Plan Parameters: num_x: " << num_x; LOG(FATAL) << "failed to initialize cufft 1d plan with customized allocator: " << status.error_message(); @@ -361,6 +363,7 @@ std::unique_ptr CUDAFft::Create2dPlan(Stream *stream, uint64 num_x, port::Status status = fft_plan_ptr->Initialize( parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr); if (!status.ok()) { + LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y; LOG(FATAL) << "failed to initialize cufft 2d plan: " << status.error_message(); } @@ -375,6 +378,7 @@ std::unique_ptr CUDAFft::Create2dPlanWithScratchAllocator( port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count, type, scratch_allocator); if (!status.ok()) { + LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y; LOG(FATAL) << "failed to initialize cufft 2d plan with customized allocator: " << status.error_message(); @@ -391,6 +395,8 @@ std::unique_ptr CUDAFft::Create3dPlan(Stream *stream, uint64 num_x, port::Status status = fft_plan_ptr->Initialize( parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr); if (!status.ok()) { + LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y + << " num_z: " << num_z; LOG(FATAL) << "failed to initialize cufft 3d plan: " << status.error_message(); } @@ -405,6 +411,8 @@ std::unique_ptr CUDAFft::Create3dPlanWithScratchAllocator( port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count, type, scratch_allocator); if (!status.ok()) { + LOG(ERROR) << "Plan Parameters: num_x: " << num_x << " num_y: " << num_y + << " num_z: " << num_z; LOG(FATAL) << "failed to initialize cufft 3d plan with customized allocator: " << status.error_message(); @@ -423,6 +431,15 @@ std::unique_ptr CUDAFft::CreateBatchedPlan( input_distance, output_embed, output_stride, output_distance, type, batch_count, /*scratch_allocator=*/nullptr); if (!status.ok()) { + LOG(ERROR) << "Initialize Params: rank: " << rank + << " elem_count: " << *elem_count + << " input_embed: " << *input_embed + << " input_stride: " << input_stride + << " input_distance: " << input_distance + << " output_embed: " << *output_embed + << " output_stride: " << output_stride + << " output_distance: " << output_distance + << " batch_count: " << batch_count; LOG(FATAL) << "failed to initialize batched cufft plan: " << status.error_message(); } @@ -441,6 +458,15 @@ std::unique_ptr CUDAFft::CreateBatchedPlanWithScratchAllocator( input_distance, output_embed, output_stride, output_distance, type, batch_count, scratch_allocator); if (!status.ok()) { + LOG(ERROR) << "Initialize Params: rank: " << rank + << " elem_count: " << *elem_count + << " input_embed: " << *input_embed + << " input_stride: " << input_stride + << " input_distance: " << input_distance + << " output_embed: " << *output_embed + << " output_stride: " << output_stride + << " output_distance: " << output_distance + << " batch_count: " << batch_count; LOG(FATAL) << "failed to initialize batched cufft plan with customized allocator: " << status.error_message(); diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index d850a45a78642a..ad9154226c4634 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -25,6 +25,7 @@ limitations under the License. #include #endif #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/cuda/cuda_driver.h" #include "tensorflow/stream_executor/cuda/cuda_event.h" @@ -146,7 +147,7 @@ port::Status CUDAExecutor::Init(int device_ordinal, } bool CUDAExecutor::FindOnDiskForComputeCapability( - port::StringPiece filename, port::StringPiece canonical_suffix, + absl::string_view filename, absl::string_view canonical_suffix, string *found_filename) const { if (cc_major_ == 0 && cc_minor_ == 0) { return false; diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 53b2a29ae7554c..90bf1c0242fb24 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/stream_executor/cuda/cuda_kernel.h" #include "tensorflow/stream_executor/event.h" #include "tensorflow/stream_executor/lib/status.h" @@ -234,8 +235,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface { // filename by looking for compute-capability-specific suffixed versions; i.e. // looking for "foo.ptx" will check to see if "foo.ptx.cc30.ptx" is present if // we're on a compute capability 3.0 machine. - bool FindOnDiskForComputeCapability(port::StringPiece filename, - port::StringPiece canonical_suffix, + bool FindOnDiskForComputeCapability(absl::string_view filename, + absl::string_view canonical_suffix, string *found_filename) const; // Host callback landing routine invoked by CUDA. diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 8a5bcf428037d5..3d8e691ab28c1b 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -219,8 +219,15 @@ std::vector ReorderDims(const std::vector& input, // -- AlgorithmConfig string AlgorithmConfig::ToString() const { - return absl::StrCat(algorithm_.algo_id(), ", ", - algorithm_no_scratch_.algo_id()); + AlgorithmDesc::Index algo_id = -1; + if (algorithm().has_value()) { + algo_id = algorithm()->algo_id(); + } + AlgorithmDesc::Index algo_id_no_scratch = -1; + if (algorithm_no_scratch().has_value()) { + algo_id_no_scratch = algorithm_no_scratch()->algo_id(); + } + return absl::StrCat(algo_id, ", ", algo_id_no_scratch); } // -- BatchDescriptor @@ -441,7 +448,6 @@ ConvolutionDescriptor::ConvolutionDescriptor(int ndims) : zero_padding_(ndims, 0), filter_strides_(ndims, 1), dilation_rates_(ndims, 1), - pad_alignment_(PadAlignment::kDefault), group_count_(1), ndims_(ndims) {} @@ -463,7 +469,7 @@ string ConvolutionDescriptor::ToString() const { return port::Printf( "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: " "%s}", - padding.c_str(), PadAlignmentString(pad_alignment_).c_str(), + padding.c_str(), PadAlignmentString(pad_alignment()).c_str(), strides.c_str(), dilations.c_str()); } diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 6dbf855a25ad6b..c934301829daa2 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/array_slice.h" @@ -547,10 +548,6 @@ class ConvolutionDescriptor { SetDim(&dilation_rates_, dim, value); return *this; } - ConvolutionDescriptor& set_pad_alignment(PadAlignment pad_alignment) { - pad_alignment_ = pad_alignment; - return *this; - } ConvolutionDescriptor& set_group_count(int group_count) { group_count_ = group_count; return *this; @@ -577,7 +574,9 @@ class ConvolutionDescriptor { int zero_padding(DimIndex dim) const { return GetDim(zero_padding_, dim); } int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); } int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); } - PadAlignment pad_alignment() const { return pad_alignment_; } + // TODO(timshen): remove this function. No users of this class is setting a + // non-default pad alignment. + PadAlignment pad_alignment() const { return PadAlignment::kDefault; } int group_count() const { return group_count_; } int ndims() const { return ndims_; } @@ -590,7 +589,6 @@ class ConvolutionDescriptor { std::vector zero_padding_; std::vector filter_strides_; std::vector dilation_rates_; - PadAlignment pad_alignment_; int group_count_; int ndims_; // TODO(leary) cudnn provides these fields, but need to characterize what @@ -716,31 +714,21 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() - : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {} AlgorithmDesc(Index a, bool use_tensor_ops) - : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {} - AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size) - : algo_(a), - tensor_ops_enabled_(use_tensor_ops), - scratch_size_(scratch_size) {} - bool is_default() const { return algo_ == kDefaultAlgorithm; } + : algo_(a), tensor_ops_enabled_(use_tensor_ops) { + DCHECK_NE(a, -1); + } bool tensor_ops_enabled() const { return tensor_ops_enabled_; } Index algo_id() const { return algo_; } - size_t scratch_size() const { return scratch_size_; } - void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmDesc& other) const { return this->algo_ == other.algo_ && - this->tensor_ops_enabled_ == other.tensor_ops_enabled_ && - this->scratch_size_ == other.scratch_size_; + this->tensor_ops_enabled_ == other.tensor_ops_enabled_; } uint64 hash() const; private: - enum { kDefaultAlgorithm = -1 }; Index algo_; bool tensor_ops_enabled_; - size_t scratch_size_; }; // Describes the result from a perf experiment. @@ -751,17 +739,25 @@ class AlgorithmDesc { class ProfileResult { public: bool is_valid() const { - return (!algorithm_.is_default() && - elapsed_time_in_ms_ != std::numeric_limits::max()); + return algorithm_.has_value() && + elapsed_time_in_ms() != std::numeric_limits::max(); } - AlgorithmDesc algorithm() const { return algorithm_; } + + AlgorithmDesc algorithm() const { return *algorithm_; } void set_algorithm(AlgorithmDesc val) { algorithm_ = val; } + float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } + size_t scratch_size() const { return scratch_size_; } + void set_scratch_size(size_t val) { scratch_size_ = val; } + private: - AlgorithmDesc algorithm_; + absl::optional algorithm_; float elapsed_time_in_ms_ = std::numeric_limits::max(); + // The scratch size algorithm_ requires. Currently it's only populated by + // convolutions. + size_t scratch_size_ = 0; }; // Describes the configuration for the algorithms that will used. @@ -776,9 +772,11 @@ class AlgorithmConfig { explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {} AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch) : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {} - AlgorithmDesc algorithm() const { return algorithm_; } + absl::optional algorithm() const { return algorithm_; } void set_algorithm(AlgorithmDesc val) { algorithm_ = val; } - AlgorithmDesc algorithm_no_scratch() const { return algorithm_no_scratch_; } + absl::optional algorithm_no_scratch() const { + return algorithm_no_scratch_; + } void set_algorithm_no_scratch(AlgorithmDesc val) { algorithm_no_scratch_ = val; } @@ -792,8 +790,8 @@ class AlgorithmConfig { string ToString() const; private: - AlgorithmDesc algorithm_; - AlgorithmDesc algorithm_no_scratch_; + absl::optional algorithm_; + absl::optional algorithm_no_scratch_; }; // Describes a local response normalization (LRN). LRN is used e.g. in @@ -920,6 +918,23 @@ class VersionInfo { // Suite of operations typically used for implementing Deep/Convolutional Neural // Nets. Note: A false return value of an operation indicates the // implementation is not available. +// +// TODO(b/118763918): this class (or rather dispatch table) has several +// problems: +// * Some overloads are missing. Ideally we want to have template virtual +// functions while the template arguments is a closed set. However, we don't +// get that from the language. +// * The API is a union of cuDNN and another private backend. Only 10% of the +// functions are actually implemented by both backends, the rest are +// actually backend-specific. The massive interface creates extra mental +// burden. +// * Poor error handling: the API should return Status objects. +// +// Things worth trying: +// * Move functions that are not actually common back to the backends. Then, +// callers may use dynamic_cast to access specific backends. This may not be +// that hard, as many of the callers are Stream::ThenXxx functions. +// * Change all the returned bools to Status. class DnnSupport { public: DnnSupport() {} diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index a994ef809eaaa2..6dda5d63155d8f 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #if !defined(PLATFORM_GOOGLE) +#include "absl/strings/string_view.h" #include "cuda/cuda_config.h" #endif @@ -119,12 +120,12 @@ static mutex& GetRpathMutex() { return *mu; } -/* static */ void DsoLoader::RegisterRpath(port::StringPiece path) { +/* static */ void DsoLoader::RegisterRpath(absl::string_view path) { mutex_lock lock{GetRpathMutex()}; GetRpaths()->emplace_back(path); } -/* static */ port::Status DsoLoader::GetDsoHandle(port::StringPiece path, +/* static */ port::Status DsoLoader::GetDsoHandle(absl::string_view path, void** dso_handle, LoadKind load_kind) { if (load_kind != LoadKind::kLocal) { @@ -190,13 +191,13 @@ static std::vector* CreatePrimordialRpaths() { #endif } -/* static */ string DsoLoader::FindDsoPath(port::StringPiece library_name, - port::StringPiece runfiles_relpath) { +/* static */ string DsoLoader::FindDsoPath(absl::string_view library_name, + absl::string_view runfiles_relpath) { // Keep a record of the paths we attempted so we can dump out meaningful // diagnostics if no path is found. std::vector attempted; - using StringPieces = std::vector; + using StringPieces = std::vector; string candidate; // Otherwise, try binary-plus-rpath locations. diff --git a/tensorflow/stream_executor/dso_loader.h b/tensorflow/stream_executor/dso_loader.h index 9ee081cb3d64e8..f063b68d6058f7 100644 --- a/tensorflow/stream_executor/dso_loader.h +++ b/tensorflow/stream_executor/dso_loader.h @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include +#include "absl/strings/string_view.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/mutex.h" @@ -48,7 +48,7 @@ class DsoLoader { static port::Status GetLibcuptiDsoHandle(void** dso_handle); // Registers a new binary-relative path to use as a dlopen search path. - static void RegisterRpath(port::StringPiece path); + static void RegisterRpath(absl::string_view path); private: // Registered rpaths (singleton vector) and a mutex that guards it. @@ -61,10 +61,9 @@ class DsoLoader { // Loads a DSO from the given "path" (which can technically be any dlopen-able // name). If the load kind is global, the symbols in the loaded DSO are // visible to subsequent DSO loading operations. - static port::Status GetDsoHandle(port::StringPiece path, void** dso_handle, + static port::Status GetDsoHandle(absl::string_view path, void** dso_handle, LoadKind load_kind = LoadKind::kLocal); - // Returns the binary directory (or binary path) associated with the currently // executing program. If strip_executable_name is true, the executable file is // stripped off of the path. @@ -80,8 +79,8 @@ class DsoLoader { // library_name: the filename in tree; e.g. libOpenCL.so.1.0.0 // runfiles_relpath: where to look for the library relative to the runfiles // root; e.g. third_party/gpus/cuda/lib64 - static string FindDsoPath(port::StringPiece library_name, - port::StringPiece runfiles_relpath); + static string FindDsoPath(absl::string_view library_name, + absl::string_view runfiles_relpath); // Return platform dependent paths for DSOs static string GetCudaLibraryDirPath(); diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc index e84b7e6cc2fbf2..240e955b6ff3d8 100644 --- a/tensorflow/stream_executor/kernel.cc +++ b/tensorflow/stream_executor/kernel.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/stream_executor/lib/demangle.h" #include "tensorflow/stream_executor/platform.h" @@ -93,9 +94,9 @@ KernelCacheConfig KernelBase::GetPreferredCacheConfig() const { // Prefix stub functions emitted by the CUDA splitter. static const char *kStubPrefix = "__device_stub_"; -void KernelBase::set_name(port::StringPiece name) { +void KernelBase::set_name(absl::string_view name) { name_ = string(name); - port::StringPiece stubless_name = name; + absl::string_view stubless_name = name; if (tensorflow::str_util::StartsWith(name, kStubPrefix)) { stubless_name.remove_prefix(strlen(kStubPrefix)); } diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index 2216884b873cda..9384db6858291d 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -75,11 +75,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/lib/array_slice.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" namespace stream_executor { @@ -178,7 +177,7 @@ class KernelBase { // Gets the preferred cache configuration for a kernel. KernelCacheConfig GetPreferredCacheConfig() const; - void set_name(port::StringPiece name); + void set_name(absl::string_view name); const string &name() const { return name_; } const string &demangled_name() const { return demangled_name_; } diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc index 1eaa0806993b1d..2e090af7169ff5 100644 --- a/tensorflow/stream_executor/kernel_spec.cc +++ b/tensorflow/stream_executor/kernel_spec.cc @@ -14,26 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/stream_executor/kernel_spec.h" +#include "absl/strings/string_view.h" namespace stream_executor { -KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname) +KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernelname) : kernelname_(string(kernelname)) {} -OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(port::StringPiece filename, - port::StringPiece kernelname) +OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename, + absl::string_view kernelname) : KernelLoaderSpec(kernelname), filename_(string(filename)) {} -CudaPtxOnDisk::CudaPtxOnDisk(port::StringPiece filename, - port::StringPiece kernelname) +CudaPtxOnDisk::CudaPtxOnDisk(absl::string_view filename, + absl::string_view kernelname) : OnDiskKernelLoaderSpec(filename, kernelname) {} -CudaCubinOnDisk::CudaCubinOnDisk(port::StringPiece filename, - port::StringPiece kernelname) +CudaCubinOnDisk::CudaCubinOnDisk(absl::string_view filename, + absl::string_view kernelname) : OnDiskKernelLoaderSpec(filename, kernelname) {} CudaCubinInMemory::CudaCubinInMemory(const char *bytes, - port::StringPiece kernelname) + absl::string_view kernelname) : KernelLoaderSpec(kernelname), bytes_(bytes) {} bool CompareComputeCapability(const std::tuple &lhs, @@ -45,8 +46,8 @@ bool CompareComputeCapability(const std::tuple &lhs, const std::tuple CudaPtxInMemory::kMinimumCapability{1, 0}; -CudaPtxInMemory::CudaPtxInMemory(port::StringPiece ptx, - port::StringPiece kernel_name, +CudaPtxInMemory::CudaPtxInMemory(absl::string_view ptx, + absl::string_view kernel_name, bool ptx_compressed) : KernelLoaderSpec(kernel_name), ptx_by_compute_capability_(CompareComputeCapability) { @@ -60,12 +61,12 @@ CudaPtxInMemory::CudaPtxInMemory(port::StringPiece ptx, CudaPtxInMemory::CudaPtxInMemory( const std::initializer_list &spec_list, - port::StringPiece kernel_name, bool ptx_compressed) + absl::string_view kernel_name, bool ptx_compressed) : KernelLoaderSpec(kernel_name), ptx_by_compute_capability_(CompareComputeCapability) { for (const auto &spec : spec_list) { int major, minor; - port::StringPiece ptx; + absl::string_view ptx; std::tie(major, minor, ptx) = spec; if (ptx_compressed) { // Lazy decompression. Put an empty string in decompressed_ptx_ showing @@ -155,62 +156,62 @@ const char *CudaPtxInMemory::original_text(int compute_capability_major, return ptx_iter->second; } -OpenCLTextOnDisk::OpenCLTextOnDisk(port::StringPiece filename, - port::StringPiece kernelname) +OpenCLTextOnDisk::OpenCLTextOnDisk(absl::string_view filename, + absl::string_view kernelname) : OnDiskKernelLoaderSpec(filename, kernelname) {} -OpenCLTextInMemory::OpenCLTextInMemory(port::StringPiece text, - port::StringPiece kernelname) +OpenCLTextInMemory::OpenCLTextInMemory(absl::string_view text, + absl::string_view kernelname) : KernelLoaderSpec(kernelname), text_(text) {} -OpenCLBinaryOnDisk::OpenCLBinaryOnDisk(port::StringPiece filename, - port::StringPiece kernelname) +OpenCLBinaryOnDisk::OpenCLBinaryOnDisk(absl::string_view filename, + absl::string_view kernelname) : OnDiskKernelLoaderSpec(filename, kernelname) {} MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddOpenCLTextOnDisk( - port::StringPiece filename, port::StringPiece kernelname) { + absl::string_view filename, absl::string_view kernelname) { CHECK(ocl_text_on_disk_ == nullptr); ocl_text_on_disk_.reset(new OpenCLTextOnDisk{filename, kernelname}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddOpenCLBinaryOnDisk( - port::StringPiece filename, port::StringPiece kernelname) { + absl::string_view filename, absl::string_view kernelname) { CHECK(ocl_binary_on_disk_ == nullptr); ocl_binary_on_disk_.reset(new OpenCLBinaryOnDisk{filename, kernelname}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddOpenCLTextInMemory( - port::StringPiece filename, port::StringPiece kernelname) { + absl::string_view filename, absl::string_view kernelname) { CHECK(ocl_text_in_memory_ == nullptr); ocl_text_in_memory_.reset(new OpenCLTextInMemory{filename, kernelname}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxOnDisk( - port::StringPiece filename, port::StringPiece kernelname) { + absl::string_view filename, absl::string_view kernelname) { CHECK(cuda_ptx_on_disk_ == nullptr); cuda_ptx_on_disk_.reset(new CudaPtxOnDisk{filename, kernelname}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCubinInMemory( - const char *bytes, port::StringPiece kernelname) { + const char *bytes, absl::string_view kernelname) { CHECK(cuda_cubin_in_memory_ == nullptr); cuda_cubin_in_memory_.reset(new CudaCubinInMemory{bytes, kernelname}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCubinOnDisk( - port::StringPiece filename, port::StringPiece kernelname) { + absl::string_view filename, absl::string_view kernelname) { CHECK(cuda_cubin_on_disk_ == nullptr); cuda_cubin_on_disk_.reset(new CudaCubinOnDisk{filename, kernelname}); return this; } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( - port::StringPiece ptx, port::StringPiece kernelname) { + absl::string_view ptx, absl::string_view kernelname) { CHECK(cuda_ptx_in_memory_ == nullptr); cuda_ptx_in_memory_.reset( new CudaPtxInMemory{ptx, kernelname, false /* ptx_compressed */}); @@ -218,7 +219,7 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( } MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory( - port::StringPiece ptx, port::StringPiece kernelname) { + absl::string_view ptx, absl::string_view kernelname) { CHECK(cuda_ptx_in_memory_ == nullptr); cuda_ptx_in_memory_.reset( new CudaPtxInMemory{ptx, kernelname, true /* ptx_compressed */}); @@ -227,7 +228,7 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory( MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( std::initializer_list spec_list, - port::StringPiece kernelname) { + absl::string_view kernelname) { CHECK(cuda_ptx_in_memory_ == nullptr); cuda_ptx_in_memory_.reset( new CudaPtxInMemory{spec_list, kernelname, false /* ptx_compressed */}); @@ -236,7 +237,7 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory( std::initializer_list spec_list, - port::StringPiece kernelname) { + absl::string_view kernelname) { CHECK(cuda_ptx_in_memory_ == nullptr); cuda_ptx_in_memory_.reset( new CudaPtxInMemory{spec_list, kernelname, true /* ptx_compressed */}); diff --git a/tensorflow/stream_executor/kernel_spec.h b/tensorflow/stream_executor/kernel_spec.h index 7cc23bb4e64b45..04b2eab084c79b 100644 --- a/tensorflow/stream_executor/kernel_spec.h +++ b/tensorflow/stream_executor/kernel_spec.h @@ -51,7 +51,7 @@ limitations under the License. #include #include "tensorflow/stream_executor/platform/port.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" +#include "absl/strings/string_view.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" @@ -76,7 +76,7 @@ class KernelLoaderSpec { const string &kernelname() const { return kernelname_; } protected: - explicit KernelLoaderSpec(port::StringPiece kernelname); + explicit KernelLoaderSpec(absl::string_view kernelname); private: // The kernel name that should be loaded out of the program description given @@ -101,8 +101,8 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec { virtual const char *CanonicalSuffix() const = 0; protected: - OnDiskKernelLoaderSpec(port::StringPiece filename, - port::StringPiece kernelname); + OnDiskKernelLoaderSpec(absl::string_view filename, + absl::string_view kernelname); string filename_; @@ -113,7 +113,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec { // Kernel loader specification for PTX text that resides on disk. class CudaPtxOnDisk : public OnDiskKernelLoaderSpec { public: - CudaPtxOnDisk(port::StringPiece filename, port::StringPiece kernelname); + CudaPtxOnDisk(absl::string_view filename, absl::string_view kernelname); ~CudaPtxOnDisk() override {} const char *CanonicalSuffix() const override { return ".ptx"; } @@ -125,7 +125,7 @@ class CudaPtxOnDisk : public OnDiskKernelLoaderSpec { // Kernel loader specification for CUBIN binary that resides on disk. class CudaCubinOnDisk : public OnDiskKernelLoaderSpec { public: - CudaCubinOnDisk(port::StringPiece filename, port::StringPiece kernelname); + CudaCubinOnDisk(absl::string_view filename, absl::string_view kernelname); ~CudaCubinOnDisk() override {} const string &filename() const { return filename_; } @@ -143,7 +143,7 @@ class CudaPtxInMemory : public KernelLoaderSpec { public: // Components: compute capability major number, compute capability minor // number, and PTX source. - typedef std::tuple PtxSpec; + typedef std::tuple PtxSpec; // Single-PTX constructor. Adds the provided PTX version with an unknown // compute capability. Since the CC is unknown, the PTX is assumed to be very @@ -151,16 +151,16 @@ class CudaPtxInMemory : public KernelLoaderSpec { // likely to be used as the default! Note that the PTX can be compressed, // which is indicated by the argument ptx_compressed. // - // Warning: the string backing the provided port::StringPiece ptx must outlive this - // instance. - CudaPtxInMemory(port::StringPiece ptx, port::StringPiece kernelname, + // Warning: the string backing the provided absl::string_view ptx must outlive + // this instance. + CudaPtxInMemory(absl::string_view ptx, absl::string_view kernelname, bool ptx_compressed = false); // Multiple-PTX-version constructor. Adds each item in spec_list to this // object. Note that the PTX can be compressed, which is indicated by the // argument ptx_compressed. CudaPtxInMemory(const std::initializer_list &spec_list, - port::StringPiece kernel_name, bool ptx_compressed = false); + absl::string_view kernel_name, bool ptx_compressed = false); ~CudaPtxInMemory() override {} // Add the PTX implementation described by ptx_spec to this object. On @@ -218,7 +218,7 @@ class CudaPtxInMemory : public KernelLoaderSpec { // Kernel loader specification for OpenCL text that resides on disk. class OpenCLTextOnDisk : public OnDiskKernelLoaderSpec { public: - OpenCLTextOnDisk(port::StringPiece filename, port::StringPiece kernelname); + OpenCLTextOnDisk(absl::string_view filename, absl::string_view kernelname); ~OpenCLTextOnDisk() override {} const char *CanonicalSuffix() const override { return ".ocl"; } @@ -230,7 +230,7 @@ class OpenCLTextOnDisk : public OnDiskKernelLoaderSpec { // Kernel loader specification for OpenCL binary that resides on disk. class OpenCLBinaryOnDisk : public OnDiskKernelLoaderSpec { public: - OpenCLBinaryOnDisk(port::StringPiece filename, port::StringPiece kernelname); + OpenCLBinaryOnDisk(absl::string_view filename, absl::string_view kernelname); ~OpenCLBinaryOnDisk() override {} const char *CanonicalSuffix() const override { return ".aocx"; } @@ -242,7 +242,7 @@ class OpenCLBinaryOnDisk : public OnDiskKernelLoaderSpec { // Kernel loader specification for OpenCL text that resides in memory. class OpenCLTextInMemory : public KernelLoaderSpec { public: - OpenCLTextInMemory(port::StringPiece text, port::StringPiece kernelname); + OpenCLTextInMemory(absl::string_view text, absl::string_view kernelname); ~OpenCLTextInMemory() override {} // Returns the OpenCL text contents. @@ -258,7 +258,7 @@ class OpenCLTextInMemory : public KernelLoaderSpec { // Kernel loader specification for a CUBIN blob that resides in memory. class CudaCubinInMemory : public KernelLoaderSpec { public: - CudaCubinInMemory(const char *bytes, port::StringPiece kernelname); + CudaCubinInMemory(const char *bytes, absl::string_view kernelname); ~CudaCubinInMemory() override {} const char *bytes() const { return bytes_; } @@ -328,28 +328,28 @@ class MultiKernelLoaderSpec { // the PTX or OpenCL being loaded. Also be aware that in CUDA C++ the kernel // name may be mangled by the compiler if it is not declared in an // extern "C" scope. - MultiKernelLoaderSpec *AddOpenCLTextOnDisk(port::StringPiece filename, - port::StringPiece kernelname); - MultiKernelLoaderSpec *AddOpenCLBinaryOnDisk(port::StringPiece filename, - port::StringPiece kernelname); - MultiKernelLoaderSpec *AddOpenCLTextInMemory(port::StringPiece ocl_text, - port::StringPiece kernelname); - MultiKernelLoaderSpec *AddCudaPtxOnDisk(port::StringPiece filename, - port::StringPiece kernelname); - MultiKernelLoaderSpec *AddCudaCubinOnDisk(port::StringPiece filename, - port::StringPiece kernelname); + MultiKernelLoaderSpec *AddOpenCLTextOnDisk(absl::string_view filename, + absl::string_view kernelname); + MultiKernelLoaderSpec *AddOpenCLBinaryOnDisk(absl::string_view filename, + absl::string_view kernelname); + MultiKernelLoaderSpec *AddOpenCLTextInMemory(absl::string_view ocl_text, + absl::string_view kernelname); + MultiKernelLoaderSpec *AddCudaPtxOnDisk(absl::string_view filename, + absl::string_view kernelname); + MultiKernelLoaderSpec *AddCudaCubinOnDisk(absl::string_view filename, + absl::string_view kernelname); MultiKernelLoaderSpec *AddCudaCubinInMemory(const char *cubin_bytes, - port::StringPiece kernelname); - MultiKernelLoaderSpec *AddCudaPtxInMemory(port::StringPiece ptx, - port::StringPiece kernelname); + absl::string_view kernelname); + MultiKernelLoaderSpec *AddCudaPtxInMemory(absl::string_view ptx, + absl::string_view kernelname); MultiKernelLoaderSpec *AddCudaCompressedPtxInMemory( - port::StringPiece ptx, port::StringPiece kernelname); + absl::string_view ptx, absl::string_view kernelname); MultiKernelLoaderSpec *AddCudaPtxInMemory( std::initializer_list spec_list, - port::StringPiece kernelname); + absl::string_view kernelname); MultiKernelLoaderSpec *AddCudaCompressedPtxInMemory( std::initializer_list spec_list, - port::StringPiece kernelname); + absl::string_view kernelname); private: std::unique_ptr diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h index d78bbfd425925f..a5eb8ef1d433be 100644 --- a/tensorflow/stream_executor/lib/env.h +++ b/tensorflow/stream_executor/lib/env.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_ +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" namespace stream_executor { @@ -31,7 +31,7 @@ inline Status FileExists(const string& filename) { return Env::Default()->FileExists(filename); } -inline Status FileExists(const port::StringPiece& filename) { +inline Status FileExists(const absl::string_view& filename) { return Env::Default()->FileExists(string(filename)); } diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h deleted file mode 100644 index 0198947e5badf9..00000000000000 --- a/tensorflow/stream_executor/lib/inlined_vector.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_ -#define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_ - -#include "absl/container/inlined_vector.h" - -namespace stream_executor { -namespace port { - -using absl::InlinedVector; - -} // namespace port -} // namespace stream_executor - -#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_ diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc index 401b87b6592ed7..47eedbc6a163af 100644 --- a/tensorflow/stream_executor/lib/path.cc +++ b/tensorflow/stream_executor/lib/path.cc @@ -15,21 +15,22 @@ limitations under the License. #include "tensorflow/stream_executor/lib/path.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" namespace stream_executor { namespace port { namespace internal { -static bool IsAbsolutePath(port::StringPiece path) { +static bool IsAbsolutePath(absl::string_view path) { return !path.empty() && path[0] == '/'; } // For an array of paths of length count, append them all together, // ensuring that the proper path separators are inserted between them. -string JoinPathImpl(std::initializer_list paths) { +string JoinPathImpl(std::initializer_list paths) { string result; - for (port::StringPiece path : paths) { + for (absl::string_view path : paths) { if (path.empty()) continue; if (result.empty()) { diff --git a/tensorflow/stream_executor/lib/path.h b/tensorflow/stream_executor/lib/path.h index 325f04ff47552e..76a623cc033ac2 100644 --- a/tensorflow/stream_executor/lib/path.h +++ b/tensorflow/stream_executor/lib/path.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_ +#include "absl/strings/string_view.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" namespace stream_executor { @@ -28,7 +28,7 @@ using tensorflow::io::Dirname; namespace internal { // TODO(rspringer): Move to cc/implementation file. // Not part of the public API. -string JoinPathImpl(std::initializer_list paths); +string JoinPathImpl(std::initializer_list paths); } // namespace internal // Join multiple paths together. @@ -44,7 +44,7 @@ string JoinPathImpl(std::initializer_list paths); // All paths will be treated as relative paths, regardless of whether or not // they start with a leading '/'. That is, all paths will be concatenated // together, with the appropriate path separator inserted in between. -// Arguments must be convertible to port::StringPiece. +// Arguments must be convertible to absl::string_view. // // Usage: // string path = file::JoinPath("/var/log", dirname, filename); diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h index 407b71b405bc8a..87269b4591a864 100644 --- a/tensorflow/stream_executor/lib/status.h +++ b/tensorflow/stream_executor/lib/status.h @@ -18,9 +18,9 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_ +#include "absl/strings/string_view.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/stream_executor/lib/error.h" // IWYU pragma: export -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/logging.h" namespace stream_executor { @@ -33,13 +33,13 @@ using Status = tensorflow::Status; ASSERT_EQ(::stream_executor::port::Status::OK(), (val)) // Define some canonical error helpers. -inline Status UnimplementedError(StringPiece message) { +inline Status UnimplementedError(absl::string_view message) { return Status(error::UNIMPLEMENTED, message); } -inline Status InternalError(StringPiece message) { +inline Status InternalError(absl::string_view message) { return Status(error::INTERNAL, message); } -inline Status FailedPreconditionError(StringPiece message) { +inline Status FailedPreconditionError(absl::string_view message) { return Status(error::FAILED_PRECONDITION, message); } diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h index e77dfcef768a38..e99dfa8399d95a 100644 --- a/tensorflow/stream_executor/lib/str_util.h +++ b/tensorflow/stream_executor/lib/str_util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_ +#include "absl/strings/string_view.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" namespace stream_executor { namespace port { @@ -27,7 +27,8 @@ using tensorflow::str_util::Split; // Returns a copy of the input string 'str' with the given 'suffix' // removed. If the suffix doesn't match, returns a copy of the original string. -inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix) { +inline string StripSuffixString(absl::string_view str, + absl::string_view suffix) { if (tensorflow::str_util::EndsWith(str, suffix)) { str.remove_suffix(suffix.size()); } diff --git a/tensorflow/stream_executor/module_spec.h b/tensorflow/stream_executor/module_spec.h index 75bdfed2d70364..e8a970283c55bb 100644 --- a/tensorflow/stream_executor/module_spec.h +++ b/tensorflow/stream_executor/module_spec.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ #include "tensorflow/stream_executor/lib/array_slice.h" -#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 59a477b5c9c37f..32f75fd1bc10b4 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/kernel_spec.h" #include "tensorflow/stream_executor/launch_dim.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/module_spec.h" diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index c06a1036399c2f..74773629d299a6 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -49,6 +49,9 @@ def register_extension_info(**kwargs): # if_cuda_is_configured def placeholder +def if_cuda_is_configured_compat(x): + return if_cuda_is_configured(x) + # Given a source file, generate a test name. # i.e. "common_runtime/direct_session_test.cc" becomes # "common_runtime_direct_session_test" @@ -170,10 +173,10 @@ def if_not_windows(a): "//conditions:default": a, }) -def if_windows(a): +def if_windows(a, otherwise = []): return select({ clean_dep("//tensorflow:windows"): a, - "//conditions:default": [], + "//conditions:default": otherwise, }) def if_not_windows_cuda(a): @@ -1042,7 +1045,7 @@ def _cuda_copts(opts = []): "@local_config_cuda//cuda:using_clang": ([ "-fcuda-flush-denormals-to-zero", ]), - }) + if_cuda_is_configured(opts) + }) + if_cuda_is_configured_compat(opts) # Build defs for TensorFlow kernels @@ -1067,7 +1070,7 @@ def tf_gpu_kernel_library( srcs = srcs, hdrs = hdrs, copts = copts, - deps = deps + if_cuda_is_configured([ + deps = deps + if_cuda_is_configured_compat([ clean_dep("//tensorflow/core:cuda"), clean_dep("//tensorflow/core:gpu_lib"), ]) + if_rocm_is_configured([ @@ -1107,7 +1110,7 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs) kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( - deps = deps + if_cuda(cuda_deps + [ + deps = deps + if_cuda_is_configured_compat(cuda_deps + [ clean_dep("//tensorflow/core:cuda"), "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured(cuda_deps + [ @@ -1504,7 +1507,7 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [ srcs = gpu_srcs, copts = _cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), features = if_cuda(["-use_header_modules"]), - deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps), + deps = deps + if_cuda_is_configured_compat(cuda_deps) + if_rocm_is_configured(rocm_deps), **kwargs ) cuda_deps.extend([":" + basename + "_gpu"]) @@ -1516,12 +1519,12 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [ clean_dep("//tensorflow/core:framework"), clean_dep("//tensorflow/core:lib"), ], - deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps), + deps = deps + if_cuda_is_configured_compat(cuda_deps) + if_rocm_is_configured(rocm_deps), ) tf_cc_shared_object( name = name, srcs = srcs, - deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps), + deps = deps + if_cuda_is_configured_compat(cuda_deps) + if_rocm_is_configured(rocm_deps), data = if_static([name + "_check_deps"]), copts = tf_copts(is_external = True), features = ["windows_export_all_symbols"], @@ -1996,7 +1999,7 @@ def tf_py_build_info_genrule(): name = "py_build_info_gen", outs = ["platform/build_info.py"], cmd = - "$(location //tensorflow/tools/build_info:gen_build_info) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"), + "$(location //tensorflow/tools/build_info:gen_build_info) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu") + if_windows(" --key_value msvcp_dll_name=msvcp140.dll", ""), local = 1, tools = [clean_dep("//tensorflow/tools/build_info:gen_build_info")], ) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index 9f6dcd8fdb0697..f7491649c22738 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -14,6 +14,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } + field { + name: "recv_buf_max_chunk" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index f3a515163df642..53b532beab344d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -137,6 +137,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } + field { + name: "recv_buf_max_chunk" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt index 353e63127de174..a2cc07483a4e10 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt @@ -78,6 +78,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT32 } + field { + name: "collective_ring_order" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_STRING + } nested_type { name: "VirtualDevices" field { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.__metaclass__.pbtxt deleted file mode 100644 index d81a3d986daded..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.Dataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt deleted file mode 100644 index eb7c8dc2644b9b..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.FixedLengthRecordDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt deleted file mode 100644 index 7cd273b2dd8cfa..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.TFRecordDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt deleted file mode 100644 index b30f93ef5d7280..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.TextLineDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt deleted file mode 100644 index 604a1dc89e6a23..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.experimental.CsvDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt deleted file mode 100644 index 0c2300a4da6f6f..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.experimental.RandomDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt deleted file mode 100644 index f1a96b03e51ec5..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.experimental.SqlDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt new file mode 100644 index 00000000000000..4635a1544c35cc --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.estimator.DNNEstimator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'head\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'False\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_saved_model" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt new file mode 100644 index 00000000000000..3d6b03098aac47 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.estimator.LinearEstimator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'head\', \'feature_columns\', \'model_dir\', \'optimizer\', \'config\', \'partitioner\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'Ftrl\', \'None\', \'None\', \'sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_saved_model" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.pbtxt index ec3216ae705709..c5b0085b8d3ec5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.pbtxt @@ -28,6 +28,10 @@ tf_module { name: "DNNClassifier" mtype: "" } + member { + name: "DNNEstimator" + mtype: "" + } member { name: "DNNLinearCombinedClassifier" mtype: "" @@ -72,6 +76,10 @@ tf_module { name: "LinearClassifier" mtype: "" } + member { + name: "LinearEstimator" + mtype: "" + } member { name: "LinearRegressor" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt new file mode 100644 index 00000000000000..0c3f04e468c4c8 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.experimental" +tf_module { + member_method { + name: "function_executor_type" + argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt deleted file mode 100644 index b1bed0c6db91e4..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorBlockDiag.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt deleted file mode 100644 index 5266853d489b04..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorCirculant.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt deleted file mode 100644 index 515714fb570ae7..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorCirculant2D.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt deleted file mode 100644 index 6d2606ccb26133..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorCirculant3D.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt deleted file mode 100644 index 09c61d4cb4f589..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorComposition.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt deleted file mode 100644 index d13f7a1e445043..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorDiag.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt deleted file mode 100644 index f8fbfac13c9c23..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorFullMatrix.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt deleted file mode 100644 index d87f5d31d33290..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorIdentity.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt deleted file mode 100644 index d721caca397ec4..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorKronecker.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt deleted file mode 100644 index 338f873788257b..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorLowRankUpdate.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt deleted file mode 100644 index 46353200385d69..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorLowerTriangular.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt deleted file mode 100644 index f3f370b35fa579..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorScaledIdentity.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt deleted file mode 100644 index 14dd9423e6decc..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorZeros.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt deleted file mode 100644 index dd5e383b5f8782..00000000000000 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperator.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index cbab7ce6314320..1a4098d121b71d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -136,6 +136,10 @@ tf_module { name: "matmul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " } + member_method { + name: "matvec" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'adjoint_a\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } member_method { name: "norm" argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt new file mode 100644 index 00000000000000..ec0d9522bca9e0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.lite.Interpreter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_path\', \'model_content\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "allocate_tensors" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_details" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_details" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor" + argspec: "args=[\'self\', \'tensor_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor_details" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "invoke" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_all_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_tensor_input" + argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_tensor" + argspec: "args=[\'self\', \'tensor_index\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tensor" + argspec: "args=[\'self\', \'tensor_index\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt new file mode 100644 index 00000000000000..1fe179f6c1b64e --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.lite.OpHint.OpHintArgumentTracker" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add" + argspec: "args=[\'self\', \'arg\', \'tag\', \'name\', \'aggregate\', \'index_override\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt new file mode 100644 index 00000000000000..66e692a5a37920 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt @@ -0,0 +1,69 @@ +path: "tensorflow.lite.OpHint" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "AGGREGATE_FIRST" + mtype: "" + } + member { + name: "AGGREGATE_LAST" + mtype: "" + } + member { + name: "AGGREGATE_STACK" + mtype: "" + } + member { + name: "FUNCTION_AGGREGATE_ATTR" + mtype: "" + } + member { + name: "FUNCTION_INPUT_INDEX_ATTR" + mtype: "" + } + member { + name: "FUNCTION_NAME_ATTR" + mtype: "" + } + member { + name: "FUNCTION_OUTPUT_INDEX_ATTR" + mtype: "" + } + member { + name: "FUNCTION_SORT_INDEX_ATTR" + mtype: "" + } + member { + name: "FUNCTION_UUID_ATTR" + mtype: "" + } + member { + name: "OpHintArgumentTracker" + mtype: "" + } + member { + name: "TFLITE_INPUT_INDICES" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_input" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_inputs" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_output" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_outputs" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt new file mode 100644 index 00000000000000..c955b1a04a4b8a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.lite.TFLiteConverter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "convert" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_frozen_graph" + argspec: "args=[\'cls\', \'graph_def_file\', \'input_arrays\', \'output_arrays\', \'input_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_keras_model_file" + argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "from_saved_model" + argspec: "args=[\'cls\', \'saved_model_dir\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'tag_set\', \'signature_key\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_session" + argspec: "args=[\'cls\', \'sess\', \'input_tensors\', \'output_tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_arrays" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-toco-converter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-toco-converter.pbtxt new file mode 100644 index 00000000000000..3ef90b8bc4646a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-toco-converter.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.lite.TocoConverter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "from_frozen_graph" + argspec: "args=[\'cls\', \'graph_def_file\', \'input_arrays\', \'output_arrays\', \'input_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_keras_model_file" + argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "from_saved_model" + argspec: "args=[\'cls\', \'saved_model_dir\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'tag_set\', \'signature_key\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_session" + argspec: "args=[\'cls\', \'sess\', \'input_tensors\', \'output_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.constants.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.constants.pbtxt new file mode 100644 index 00000000000000..08845553e55d3b --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.constants.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.lite.constants" +tf_module { + member { + name: "FLOAT" + mtype: "" + } + member { + name: "GRAPHVIZ_DOT" + mtype: "" + } + member { + name: "INT32" + mtype: "" + } + member { + name: "INT64" + mtype: "" + } + member { + name: "QUANTIZED_UINT8" + mtype: "" + } + member { + name: "STRING" + mtype: "" + } + member { + name: "TFLITE" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.pbtxt new file mode 100644 index 00000000000000..f5013c250be847 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.lite" +tf_module { + member { + name: "Interpreter" + mtype: "" + } + member { + name: "OpHint" + mtype: "" + } + member { + name: "TFLiteConverter" + mtype: "" + } + member { + name: "TocoConverter" + mtype: "" + } + member { + name: "constants" + mtype: "" + } + member_method { + name: "toco_convert" + argspec: "args=[\'input_data\', \'input_tensors\', \'output_tensors\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 1926899c841ab3..9597dd7684eacf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -340,6 +340,10 @@ tf_module { name: "estimator" mtype: "" } + member { + name: "experimental" + mtype: "" + } member { name: "feature_column" mtype: "" @@ -420,6 +424,10 @@ tf_module { name: "linalg" mtype: "" } + member { + name: "lite" + mtype: "" + } member { name: "logging" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.experimental.pbtxt new file mode 100644 index 00000000000000..34343e7c039a37 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.saved_model.experimental" +tf_module { + member_method { + name: "save" + argspec: "args=[\'obj\', \'export_dir\', \'signatures\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt index 5b28f7b9b1824e..2055bfbf066cbb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt @@ -108,6 +108,10 @@ tf_module { name: "constants" mtype: "" } + member { + name: "experimental" + mtype: "" + } member { name: "loader" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt index 9f6dcd8fdb0697..f7491649c22738 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt @@ -14,6 +14,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } + field { + name: "recv_buf_max_chunk" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt index f3a515163df642..53b532beab344d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt @@ -137,6 +137,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } + field { + name: "recv_buf_max_chunk" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt index 353e63127de174..a2cc07483a4e10 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt @@ -78,6 +78,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT32 } + field { + name: "collective_ring_order" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_STRING + } nested_type { name: "VirtualDevices" field { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.__metaclass__.pbtxt deleted file mode 100644 index d81a3d986daded..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.Dataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt deleted file mode 100644 index eb7c8dc2644b9b..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.FixedLengthRecordDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt deleted file mode 100644 index 7cd273b2dd8cfa..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.TFRecordDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt deleted file mode 100644 index b30f93ef5d7280..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.TextLineDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt deleted file mode 100644 index 604a1dc89e6a23..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.experimental.CsvDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt deleted file mode 100644 index 0c2300a4da6f6f..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.experimental.RandomDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt deleted file mode 100644 index f1a96b03e51ec5..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.data.experimental.SqlDataset.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt new file mode 100644 index 00000000000000..4635a1544c35cc --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.estimator.DNNEstimator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'head\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'False\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_saved_model" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt new file mode 100644 index 00000000000000..3d6b03098aac47 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.estimator.LinearEstimator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'head\', \'feature_columns\', \'model_dir\', \'optimizer\', \'config\', \'partitioner\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'Ftrl\', \'None\', \'None\', \'sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_saved_model" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt index ec3216ae705709..c5b0085b8d3ec5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt @@ -28,6 +28,10 @@ tf_module { name: "DNNClassifier" mtype: "" } + member { + name: "DNNEstimator" + mtype: "" + } member { name: "DNNLinearCombinedClassifier" mtype: "" @@ -72,6 +76,10 @@ tf_module { name: "LinearClassifier" mtype: "" } + member { + name: "LinearEstimator" + mtype: "" + } member { name: "LinearRegressor" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt new file mode 100644 index 00000000000000..0c3f04e468c4c8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.experimental" +tf_module { + member_method { + name: "function_executor_type" + argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt deleted file mode 100644 index b1bed0c6db91e4..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorBlockDiag.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt deleted file mode 100644 index 5266853d489b04..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorCirculant.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt deleted file mode 100644 index 515714fb570ae7..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorCirculant2D.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt deleted file mode 100644 index 6d2606ccb26133..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorCirculant3D.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt deleted file mode 100644 index 09c61d4cb4f589..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorComposition.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt deleted file mode 100644 index d13f7a1e445043..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorDiag.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt deleted file mode 100644 index f8fbfac13c9c23..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorFullMatrix.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt deleted file mode 100644 index d87f5d31d33290..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorIdentity.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt deleted file mode 100644 index d721caca397ec4..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorKronecker.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt deleted file mode 100644 index 338f873788257b..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorLowRankUpdate.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt deleted file mode 100644 index 46353200385d69..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorLowerTriangular.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt deleted file mode 100644 index f3f370b35fa579..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorScaledIdentity.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt deleted file mode 100644 index 14dd9423e6decc..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperatorZeros.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt deleted file mode 100644 index dd5e383b5f8782..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt +++ /dev/null @@ -1,14 +0,0 @@ -path: "tensorflow.linalg.LinearOperator.__metaclass__" -tf_class { - is_instance: "" - member_method { - name: "__init__" - } - member_method { - name: "mro" - } - member_method { - name: "register" - argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index cbab7ce6314320..1a4098d121b71d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -136,6 +136,10 @@ tf_module { name: "matmul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " } + member_method { + name: "matvec" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'adjoint_a\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } member_method { name: "norm" argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt new file mode 100644 index 00000000000000..ec0d9522bca9e0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.lite.Interpreter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_path\', \'model_content\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "allocate_tensors" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_details" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_details" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor" + argspec: "args=[\'self\', \'tensor_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor_details" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "invoke" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_all_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_tensor_input" + argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_tensor" + argspec: "args=[\'self\', \'tensor_index\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tensor" + argspec: "args=[\'self\', \'tensor_index\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt new file mode 100644 index 00000000000000..1fe179f6c1b64e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.lite.OpHint.OpHintArgumentTracker" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add" + argspec: "args=[\'self\', \'arg\', \'tag\', \'name\', \'aggregate\', \'index_override\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt new file mode 100644 index 00000000000000..66e692a5a37920 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt @@ -0,0 +1,69 @@ +path: "tensorflow.lite.OpHint" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "AGGREGATE_FIRST" + mtype: "" + } + member { + name: "AGGREGATE_LAST" + mtype: "" + } + member { + name: "AGGREGATE_STACK" + mtype: "" + } + member { + name: "FUNCTION_AGGREGATE_ATTR" + mtype: "" + } + member { + name: "FUNCTION_INPUT_INDEX_ATTR" + mtype: "" + } + member { + name: "FUNCTION_NAME_ATTR" + mtype: "" + } + member { + name: "FUNCTION_OUTPUT_INDEX_ATTR" + mtype: "" + } + member { + name: "FUNCTION_SORT_INDEX_ATTR" + mtype: "" + } + member { + name: "FUNCTION_UUID_ATTR" + mtype: "" + } + member { + name: "OpHintArgumentTracker" + mtype: "" + } + member { + name: "TFLITE_INPUT_INDICES" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_input" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_inputs" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_output" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_outputs" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt new file mode 100644 index 00000000000000..c955b1a04a4b8a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.lite.TFLiteConverter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "convert" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_frozen_graph" + argspec: "args=[\'cls\', \'graph_def_file\', \'input_arrays\', \'output_arrays\', \'input_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_keras_model_file" + argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "from_saved_model" + argspec: "args=[\'cls\', \'saved_model_dir\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'tag_set\', \'signature_key\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_session" + argspec: "args=[\'cls\', \'sess\', \'input_tensors\', \'output_tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_arrays" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-toco-converter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-toco-converter.pbtxt new file mode 100644 index 00000000000000..3ef90b8bc4646a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-toco-converter.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.lite.TocoConverter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "from_frozen_graph" + argspec: "args=[\'cls\', \'graph_def_file\', \'input_arrays\', \'output_arrays\', \'input_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_keras_model_file" + argspec: "args=[\'cls\', \'model_file\', \'input_arrays\', \'input_shapes\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "from_saved_model" + argspec: "args=[\'cls\', \'saved_model_dir\', \'input_arrays\', \'input_shapes\', \'output_arrays\', \'tag_set\', \'signature_key\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_session" + argspec: "args=[\'cls\', \'sess\', \'input_tensors\', \'output_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.constants.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.constants.pbtxt new file mode 100644 index 00000000000000..08845553e55d3b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.constants.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.lite.constants" +tf_module { + member { + name: "FLOAT" + mtype: "" + } + member { + name: "GRAPHVIZ_DOT" + mtype: "" + } + member { + name: "INT32" + mtype: "" + } + member { + name: "INT64" + mtype: "" + } + member { + name: "QUANTIZED_UINT8" + mtype: "" + } + member { + name: "STRING" + mtype: "" + } + member { + name: "TFLITE" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.pbtxt new file mode 100644 index 00000000000000..f5013c250be847 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.lite" +tf_module { + member { + name: "Interpreter" + mtype: "" + } + member { + name: "OpHint" + mtype: "" + } + member { + name: "TFLiteConverter" + mtype: "" + } + member { + name: "TocoConverter" + mtype: "" + } + member { + name: "constants" + mtype: "" + } + member_method { + name: "toco_convert" + argspec: "args=[\'input_data\', \'input_tensors\', \'output_tensors\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.manip.pbtxt deleted file mode 100644 index d6924d26b9a8b9..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.manip.pbtxt +++ /dev/null @@ -1,31 +0,0 @@ -path: "tensorflow.manip" -tf_module { - member_method { - name: "batch_to_space_nd" - argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "gather_nd" - argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reshape" - argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reverse" - argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "scatter_nd" - argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "space_to_batch_nd" - argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "tile" - argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 28b03ef0a2b6b5..7c865bb0022536 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -104,10 +104,6 @@ tf_module { name: "OptimizerOptions" mtype: "" } - member { - name: "QUANTIZED_DTYPES" - mtype: "" - } member { name: "RegisterGradient" mtype: "" @@ -240,6 +236,10 @@ tf_module { name: "estimator" mtype: "" } + member { + name: "experimental" + mtype: "" + } member { name: "feature_column" mtype: "" @@ -313,15 +313,15 @@ tf_module { mtype: "" } member { - name: "logging" + name: "lite" mtype: "" } member { - name: "losses" + name: "logging" mtype: "" } member { - name: "manip" + name: "losses" mtype: "" } member { @@ -596,10 +596,6 @@ tf_module { name: "batch_to_space_nd" argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "betainc" - argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "bitcast" argspec: "args=[\'input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -628,18 +624,6 @@ tf_module { name: "cast" argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "ceil" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "check_numerics" - argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "cholesky" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "clip_by_average_norm" argspec: "args=[\'t\', \'clip_norm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -656,10 +640,6 @@ tf_module { name: "clip_by_value" argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "colocate_with" - argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], " - } member_method { name: "complex" argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -712,10 +692,6 @@ tf_module { name: "create_partitioned_variables" argspec: "args=[\'shape\', \'slicing\', \'initializer\', \'dtype\', \'trainable\', \'collections\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\"\", \'True\', \'None\', \'None\', \'None\'], " } - member_method { - name: "cross" - argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "cumsum" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " @@ -724,42 +700,10 @@ tf_module { name: "custom_gradient" argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "decode_base64" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "decode_compressed" - argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " - } - member_method { - name: "decode_json_example" - argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "decode_raw" - argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " - } - member_method { - name: "dequantize" - argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], " - } member_method { name: "device" argspec: "args=[\'device_name_or_function\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "diag" - argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "diag_part" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "digamma" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "div" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -792,10 +736,6 @@ tf_module { name: "enable_eager_execution" argspec: "args=[\'config\', \'device_policy\', \'execution_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } - member_method { - name: "encode_base64" - argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " - } member_method { name: "ensure_shape" argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -804,10 +744,6 @@ tf_module { name: "equal" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "erfc" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "executing_eagerly" argspec: "args=[], varargs=None, keywords=None, defaults=None" @@ -820,14 +756,6 @@ tf_module { name: "expand_dims" argspec: "args=[\'input\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } - member_method { - name: "expm1" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "extract_image_patches" - argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "extract_volume_patches" argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -836,34 +764,6 @@ tf_module { name: "eye" argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'None\'], " } - member_method { - name: "fake_quant_with_min_max_args" - argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " - } - member_method { - name: "fake_quant_with_min_max_args_gradient" - argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " - } - member_method { - name: "fake_quant_with_min_max_vars" - argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " - } - member_method { - name: "fake_quant_with_min_max_vars_gradient" - argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " - } - member_method { - name: "fake_quant_with_min_max_vars_per_channel" - argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " - } - member_method { - name: "fake_quant_with_min_max_vars_per_channel_gradient" - argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " - } - member_method { - name: "fft" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "fft2d" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -960,10 +860,6 @@ tf_module { name: "identity_n" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "ifft" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "ifft2d" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -972,14 +868,6 @@ tf_module { name: "ifft3d" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "igamma" - argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "igammac" - argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "import_graph_def" argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " @@ -992,22 +880,6 @@ tf_module { name: "initialize_all_tables" argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " } - member_method { - name: "invert_permutation" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "is_finite" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "is_inf" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "is_nan" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "less" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1016,10 +888,6 @@ tf_module { name: "less_equal" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "lgamma" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "lin_space" argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1076,50 +944,14 @@ tf_module { name: "map_fn" argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], " } - member_method { - name: "matching_files" - argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "matmul" argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " } - member_method { - name: "matrix_band_part" - argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "matrix_determinant" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "matrix_diag" - argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "matrix_diag_part" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "matrix_inverse" - argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " - } - member_method { - name: "matrix_set_diag" - argspec: "args=[\'input\', \'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "matrix_solve" - argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " - } member_method { name: "matrix_square_root" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "matrix_triangular_solve" - argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " - } member_method { name: "maximum" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1204,10 +1036,6 @@ tf_module { name: "parse_single_sequence_example" argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } - member_method { - name: "parse_tensor" - argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "placeholder" argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " @@ -1216,10 +1044,6 @@ tf_module { name: "placeholder_with_default" argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "polygamma" - argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "pow" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1232,18 +1056,10 @@ tf_module { name: "py_func" argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " } - member_method { - name: "qr" - argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " - } member_method { name: "quantize_v2" argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'name\', \'round_mode\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'HALF_AWAY_FROM_ZERO\'], " } - member_method { - name: "quantized_concat" - argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "random_crop" argspec: "args=[\'value\', \'size\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " @@ -1268,18 +1084,10 @@ tf_module { name: "rank" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "read_file" - argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "realdiv" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "reciprocal" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reduce_all" argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " @@ -1336,14 +1144,6 @@ tf_module { name: "reverse_sequence" argspec: "args=[\'input\', \'seq_lengths\', \'seq_axis\', \'batch_axis\', \'name\', \'seq_dim\', \'batch_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } - member_method { - name: "reverse_v2" - argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "rint" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "roll" argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" @@ -1352,10 +1152,6 @@ tf_module { name: "round" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "rsqrt" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "saturate_cast" argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1392,26 +1188,6 @@ tf_module { name: "searchsorted" argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"\", \'None\'], " } - member_method { - name: "segment_max" - argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "segment_mean" - argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "segment_min" - argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "segment_prod" - argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "segment_sum" - argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "sequence_mask" argspec: "args=[\'lengths\', \'maxlen\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " @@ -1508,10 +1284,6 @@ tf_module { name: "square" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "squared_difference" - argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "squeeze" argspec: "args=[\'input\', \'axis\', \'name\', \'squeeze_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " @@ -1528,34 +1300,10 @@ tf_module { name: "strided_slice" argspec: "args=[\'input_\', \'begin\', \'end\', \'strides\', \'begin_mask\', \'end_mask\', \'ellipsis_mask\', \'new_axis_mask\', \'shrink_axis_mask\', \'var\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'0\', \'0\', \'0\', \'None\', \'None\'], " } - member_method { - name: "string_join" - argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " - } member_method { name: "string_split" argspec: "args=[\'source\', \'delimiter\', \'skip_empty\'], varargs=None, keywords=None, defaults=[\' \', \'True\'], " } - member_method { - name: "string_strip" - argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "string_to_hash_bucket" - argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "string_to_hash_bucket_fast" - argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "string_to_hash_bucket_strong" - argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "string_to_number" - argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " - } member_method { name: "substr" argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], " @@ -1652,22 +1400,6 @@ tf_module { name: "unravel_index" argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "unsorted_segment_max" - argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "unsorted_segment_min" - argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "unsorted_segment_prod" - argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "unsorted_segment_sum" - argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "unstack" argspec: "args=[\'value\', \'num\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'unstack\'], " @@ -1688,10 +1420,6 @@ tf_module { name: "while_loop" argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], " } - member_method { - name: "write_file" - argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "zeros" argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " @@ -1700,8 +1428,4 @@ tf_module { name: "zeros_like" argspec: "args=[\'tensor\', \'dtype\', \'name\', \'optimize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " } - member_method { - name: "zeta" - argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt deleted file mode 100644 index 67457de070830d..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt +++ /dev/null @@ -1,21 +0,0 @@ -path: "tensorflow.saved_model.Builder" -tf_class { - is_instance: "" - is_instance: "" - member_method { - name: "__init__" - argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "add_meta_graph" - argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], " - } - member_method { - name: "add_meta_graph_and_variables" - argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], " - } - member_method { - name: "save" - argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt index dc26a67fa0e1f8..d57936a2f1cb9e 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt @@ -8,10 +8,6 @@ tf_module { name: "ASSETS_KEY" mtype: "" } - member { - name: "Builder" - mtype: "" - } member { name: "CLASSIFY_INPUTS" mtype: "" @@ -104,34 +100,14 @@ tf_module { name: "build_signature_def" argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } - member_method { - name: "build_tensor_info" - argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "classification_signature_def" argspec: "args=[\'examples\', \'classes\', \'scores\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "get_tensor_from_tensor_info" - argspec: "args=[\'tensor_info\', \'graph\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } member_method { name: "is_valid_signature" argspec: "args=[\'signature_def\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "load" - argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], " - } - member_method { - name: "main_op" - argspec: "args=[], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "main_op_with_restore" - argspec: "args=[\'restore_op_name\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "maybe_saved_model_directory" argspec: "args=[\'export_dir\'], varargs=None, keywords=None, defaults=None" @@ -145,7 +121,7 @@ tf_module { argspec: "args=[\'examples\', \'predictions\'], varargs=None, keywords=None, defaults=None" } member_method { - name: "simple_save" - argspec: "args=[\'session\', \'export_dir\', \'inputs\', \'outputs\', \'legacy_init_op\'], varargs=None, keywords=None, defaults=[\'None\'], " + name: "save" + argspec: "args=[\'obj\', \'export_dir\', \'signatures\'], varargs=None, keywords=None, defaults=[\'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-sync-replicas-optimizer.pbtxt deleted file mode 100644 index 2c0fda3c72b7e1..00000000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-sync-replicas-optimizer.pbtxt +++ /dev/null @@ -1,63 +0,0 @@ -path: "tensorflow.train.SyncReplicasOptimizer" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "GATE_GRAPH" - mtype: "" - } - member { - name: "GATE_NONE" - mtype: "" - } - member { - name: "GATE_OP" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'opt\', \'replicas_to_aggregate\', \'total_num_replicas\', \'variable_averages\', \'variables_to_average\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'sync_replicas\'], " - } - member_method { - name: "apply_gradients" - argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "compute_gradients" - argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "get_chief_queue_runner" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_init_tokens_op" - argspec: "args=[\'self\', \'num_tokens\'], varargs=None, keywords=None, defaults=[\'-1\'], " - } - member_method { - name: "get_name" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_slot" - argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "get_slot_names" - argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "make_session_run_hook" - argspec: "args=[\'self\', \'is_chief\', \'num_tokens\'], varargs=None, keywords=None, defaults=[\'-1\'], " - } - member_method { - name: "minimize" - argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "variables" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt index c2dc4140e8ebe1..582c0ee3d035d6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt @@ -212,10 +212,6 @@ tf_module { name: "Supervisor" mtype: "" } - member { - name: "SyncReplicasOptimizer" - mtype: "" - } member { name: "VocabInfo" mtype: "" diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index fae35135f278f4..70df38ba8b8c46 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -51,6 +51,10 @@ _NORMALIZE_TYPE[""] = "" _NORMALIZE_ISINSTANCE = { "": # pylint: disable=line-too-long + "", + "": # pylint: disable=line-too-long "", diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 66c4f1cbed80c9..2c348a0e3390af 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -439,9 +439,9 @@ cmd_status(){ # out by default in TF WORKSPACE file. do_bazel_nobuild() { BUILD_TARGET="//tensorflow/..." - BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/java/demo/app/..." - BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/examples/android/..." - BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/schema/..." + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/lite/java/demo/app/..." + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/lite/examples/android/..." + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/lite/schema/..." BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} -- ${BUILD_TARGET}" ${BUILD_CMD} diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh index 5e5c2588467c31..989f2a92eb6e59 100755 --- a/tensorflow/tools/ci_build/install/install_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh @@ -38,13 +38,13 @@ if [[ "$ubuntu_version" == "14" ]]; then apt-get dist-upgrade -y fi +## TODO(yifeif) remove ffmpeg once ffmpeg is removed from contrib apt-get install -y --no-install-recommends \ autoconf \ automake \ build-essential \ clang-format-3.8 \ curl \ - ## TODO(yifeif) remove once ffmpeg is removed from contrib ffmpeg \ git \ libcurl4-openssl-dev \ diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh index 5c5a36139f50e8..3efd994d783d8f 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh @@ -35,4 +35,4 @@ bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac,-no_mac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium --config=opt \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ - //tensorflow/contrib/... -//tensorflow/contrib/lite/... + //tensorflow/contrib/... -//tensorflow/lite/... diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py index 23cc4a21a9e6f8..a5b9fbdae8be9a 100644 --- a/tensorflow/tools/compatibility/ast_edits.py +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -184,6 +184,17 @@ def _rename_functions(self, node, full_name): except KeyError: pass + def _print_warning_for_function(self, node, full_name): + function_warnings = self._api_change_spec.function_warnings + try: + warning_message = function_warnings[full_name] + warning_message = warning_message.replace("", full_name) + self._file_edit.add(warning_message, + node.lineno, node.col_offset, full_name, full_name, + error="%s requires manual check." % full_name) + except KeyError: + pass + def _get_attribute_full_path(self, node): """Traverse an attribute to generate a full name e.g. tf.foo.bar. @@ -350,6 +361,7 @@ def visit_Attribute(self, node): # pylint: disable=invalid-name full_name = self._get_attribute_full_path(node) if full_name: self._rename_functions(node, full_name) + self._print_warning_for_function(node, full_name) if full_name in self._api_change_spec.change_to_function: if not hasattr(node, "is_function_for_call"): new_text = full_name + "()" diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index 9362b36bd2d073..260278878fa8d9 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -141,18 +141,12 @@ 'tf.reverse_v2': 'tf.reverse', 'tf.rint': 'tf.math.rint', 'tf.rsqrt': 'tf.math.rsqrt', - 'tf.saved_model.builder.SavedModelBuilder': 'tf.saved_model.Builder', - 'tf.saved_model.loader.load': 'tf.saved_model.load', 'tf.saved_model.loader.maybe_saved_model_directory': 'tf.saved_model.maybe_saved_model_directory', - 'tf.saved_model.main_op.main_op': 'tf.saved_model.main_op', - 'tf.saved_model.main_op.main_op_with_restore': 'tf.saved_model.main_op_with_restore', 'tf.saved_model.signature_def_utils.build_signature_def': 'tf.saved_model.build_signature_def', 'tf.saved_model.signature_def_utils.classification_signature_def': 'tf.saved_model.classification_signature_def', 'tf.saved_model.signature_def_utils.is_valid_signature': 'tf.saved_model.is_valid_signature', 'tf.saved_model.signature_def_utils.predict_signature_def': 'tf.saved_model.predict_signature_def', 'tf.saved_model.signature_def_utils.regression_signature_def': 'tf.saved_model.regression_signature_def', - 'tf.saved_model.utils.build_tensor_info': 'tf.saved_model.build_tensor_info', - 'tf.saved_model.utils.get_tensor_from_tensor_info': 'tf.saved_model.get_tensor_from_tensor_info', 'tf.segment_max': 'tf.math.segment_max', 'tf.segment_mean': 'tf.math.segment_mean', 'tf.segment_min': 'tf.math.segment_min', diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 96705b1a4c27e7..2dabf7834dad62 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -178,6 +178,9 @@ def __init__(self): # Specially handled functions. self.function_handle = {"tf.reverse": self._reverse_handler} + # Warnings that should be printed if corresponding functions are used. + self.function_warnings = {} + @staticmethod def _reverse_handler(file_edit_recorder, node): # TODO(aselle): Could check for a literal list of bools and try to convert diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 53c546b10c01b0..dda45468fcdd3b 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -19,7 +19,6 @@ from __future__ import print_function import argparse -import functools from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import renames_v2 @@ -35,6 +34,48 @@ def __init__(self): # Mapping from function to the new name of the function self.function_renames = renames_v2.renames + # pylint: disable=line-too-long + self.function_renames.update({ + "tf.FixedLengthRecordReader": "tf.compat.v1.FixedLengthRecordReader", + "tf.IdentityReader": "tf.compat.v1.IdentityReader", + "tf.LMDBReader": "tf.compat.v1.LMDBReader", + "tf.ReaderBase": "tf.compat.v1.ReaderBase", + "tf.TFRecordReader": "tf.compat.v1.TFRecordReader", + "tf.TextLineReader": "tf.compat.v1.TextLineReader", + "tf.WholeFileReader": "tf.compat.v1.WholeFileReader", + "tf.saved_model.builder.SavedModelBuilder": "tf.compat.v1.saved_model.Builder", + "tf.saved_model.loader.load": "tf.compat.v1.saved_model.load", + "tf.saved_model.main_op.main_op": "tf.compat.v1.saved_model.main_op", + "tf.saved_model.main_op.main_op_with_restore": "tf.compat.v1.saved_model.main_op_with_restore", + "tf.saved_model.simple_save": "tf.compat.v1.saved_model.simple_save", + "tf.saved_model.utils.build_tensor_info": "tf.compat.v1.saved_model.build_tensor_info", + "tf.saved_model.utils.get_tensor_from_tensor_info": "tf.compat.v1.saved_model.get_tensor_from_tensor_info", + "tf.train.QueueRunner": "tf.compat.v1.QueueRunner", + "tf.train.add_queue_runner": "tf.compat.v1.add_queue_runner", + "tf.train.batch": "tf.compat.v1.train.batch", + "tf.train.batch_join": "tf.compat.v1.train.batch_join", + "tf.train.input_producer": "tf.compat.v1.train.input_producer", + "tf.train.limit_epochs": "tf.compat.v1.train.limit_epochs", + "tf.train.maybe_batch": "tf.compat.v1.train.maybe_batch", + "tf.train.maybe_batch_join": "tf.compat.v1.train.maybe_batch_join", + "tf.train.maybe_shuffle_batch": "tf.compat.v1.train.maybe_shuffle_batch", + "tf.train.maybe_shuffle_batch_join": "tf.compat.v1.train.maybe_shuffle_batch_join", + "tf.train.queue_runner.QueueRunner": "tf.compat.v1.queue_runner.QueueRunner", + "tf.train.queue_runner.add_queue_runner": "tf.compat.v1.queue_runner.add_queue_runner", + "tf.train.queue_runner.start_queue_runners": "tf.compat.v1.queue_runner.start_queue_runners", + "tf.train.range_input_producer": "tf.compat.v1.train.range_input_producer", + "tf.train.shuffle_batch": "tf.compat.v1.train.shuffle_batch", + "tf.train.shuffle_batch_join": "tf.compat.v1.train.shuffle_batch_join", + "tf.train.slice_input_producer": "tf.compat.v1.train.slice_input_producer", + "tf.train.string_input_producer": "tf.compat.v1.train.string_input_producer", + "tf.train.start_queue_runners": "tf.compat.v1.start_queue_runners", + }) + # pylint: enable=line-too-long + self.function_renames["tf.colocate_with"] = "tf.compat.v1.colocate_with" + + # TODO(amitpatankar): Fix the function rename script + # to handle constants without hardcoding. + self.function_renames["QUANTIZED_DTYPES"] = "dtypes.QUANTIZED_DTYPES" # Variables that should be changed to functions. self.change_to_function = {} @@ -46,29 +87,28 @@ def __init__(self): # Specially handled functions. self.function_handle = {} - for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant", - "tf.train.polynomial_decay", "tf.train.natural_exp_decay", - "tf.train.inverse_time_decay", "tf.train.cosine_decay", - "tf.train.cosine_decay_restarts", - "tf.train.linear_cosine_decay", - "tf.train.noisy_linear_cosine_decay"]: - self.function_handle[decay] = functools.partial( - self._learning_rate_decay_handler, decay_name=decay) - - @staticmethod - def _learning_rate_decay_handler(file_edit_recorder, node, decay_name): - comment = ("ERROR: %s has been changed to return a callable instead of a " - "tensor when graph building, but its functionality remains " - "unchanged during eager execution (returns a callable like " - "before). The converter cannot detect and fix this reliably, so " - "you need to inspect this usage manually.\n") % decay_name - file_edit_recorder.add( - comment, - node.lineno, - node.col_offset, - decay_name, - decay_name, - error="%s requires manual check." % decay_name) + + decay_function_comment = ( + "ERROR: has been changed to return a callable instead " + "of a tensor when graph building, but its functionality remains " + "unchanged during eager execution (returns a callable like " + "before). The converter cannot detect and fix this reliably, so " + "you need to inspect this usage manually.\n" + ) + + # Function warnings. placeholder inside warnings will be + # replaced by function name. + self.function_warnings = { + "tf.train.exponential_decay": decay_function_comment, + "tf.train.piecewise_constant": decay_function_comment, + "tf.train.polynomial_decay": decay_function_comment, + "tf.train.natural_exp_decay": decay_function_comment, + "tf.train.inverse_time_decay": decay_function_comment, + "tf.train.cosine_decay": decay_function_comment, + "tf.train.cosine_decay_restarts": decay_function_comment, + "tf.train.linear_cosine_decay": decay_function_comment, + "tf.train.noisy_linear_cosine_decay": decay_function_comment, + } if __name__ == "__main__": diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index f606d202a608cc..6a0c3a787dafdf 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -73,9 +73,10 @@ def testLearningRateDecay(self): "tf.train.noisy_linear_cosine_decay"]: text = "%s(a, b)\n" % decay - _, unused_report, errors, new_text = self._upgrade(text) + _, report, errors, new_text = self._upgrade(text) self.assertEqual(text, new_text) self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay]) + self.assertIn("%s has been changed" % decay, report) class TestUpgradeFiles(test_util.TensorFlowTestCase): diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 970dd49e117dc3..f9b0a1129b71ff 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -153,6 +153,7 @@ filegroup( "@highwayhash//:LICENSE", "@icu//:icu4c/LICENSE", "@jpeg//:LICENSE.md", + "@keras_applications_archive//:LICENSE", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@nasm//:LICENSE", @@ -226,15 +227,15 @@ sh_binary( data = select({ "//tensorflow:windows": [ ":simple_console_for_windows", - "//tensorflow/contrib/lite/python:interpreter_test_data", - "//tensorflow/contrib/lite/python:tflite_convert", - "//tensorflow/contrib/lite/toco/python:toco_from_protos", + "//tensorflow/lite/python:interpreter_test_data", + "//tensorflow/lite/python:tflite_convert", + "//tensorflow/lite/toco/python:toco_from_protos", ], "//conditions:default": COMMON_PIP_DEPS + [ ":simple_console", - "//tensorflow/contrib/lite/python:interpreter_test_data", - "//tensorflow/contrib/lite/python:tflite_convert", - "//tensorflow/contrib/lite/toco/python:toco_from_protos", + "//tensorflow/lite/python:interpreter_test_data", + "//tensorflow/lite/python:tflite_convert", + "//tensorflow/lite/toco/python:toco_from_protos", ], }) + if_mkl_ml(["//third_party/mkl:intel_binary_blob"]), ) diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index c62271c5cb1731..439b5428b3b7bf 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -120,7 +120,7 @@ function prepare_src() { fi mkdir "${TMPDIR}/tensorflow/aux-bin" # Install toco as a binary in aux-bin. - cp bazel-bin/tensorflow/contrib/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/ + cp bazel-bin/tensorflow/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/ fi # protobuf pip package doesn't ship with header files. Copy the headers diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index c77845caee2bed..ff821b864300c1 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -37,7 +37,7 @@ def GetBuild(dir_base): for root, _, files in os.walk(dir_base): for name in files: if (name == "BUILD" and - root.find("tensorflow/contrib/lite/examples/android") == -1): + root.find("tensorflow/lite/examples/android") == -1): items.append("//" + root + ":all") return items @@ -85,14 +85,14 @@ def BuildPyTestDependencies(): # contrib "//tensorflow/contrib/session_bundle:session_bundle_half_plus_two", "//tensorflow/contrib/keras:testing_utils", - "//tensorflow/contrib/lite/experimental/examples/lstm:tflite_lstm", - "//tensorflow/contrib/lite/experimental/examples/lstm:tflite_lstm.py", - "//tensorflow/contrib/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long - "//tensorflow/contrib/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test.py", # pylint:disable=line-too-long - "//tensorflow/contrib/lite/python:interpreter", - "//tensorflow/contrib/lite/python:interpreter_test", - "//tensorflow/contrib/lite/python:interpreter.py", - "//tensorflow/contrib/lite/python:interpreter_test.py", + "//tensorflow/lite/experimental/examples/lstm:tflite_lstm", + "//tensorflow/lite/experimental/examples/lstm:tflite_lstm.py", + "//tensorflow/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long + "//tensorflow/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test.py", # pylint:disable=line-too-long + "//tensorflow/lite/python:interpreter", + "//tensorflow/lite/python:interpreter_test", + "//tensorflow/lite/python:interpreter.py", + "//tensorflow/lite/python:interpreter_test.py", "//tensorflow/contrib/ffmpeg:test_data", "//tensorflow/contrib/fused_conv:fused_conv2d_bias_activation_op_test_base", "//tensorflow/contrib/hadoop:test_data", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 83df96fb16ac85..07475cc0c4de6b 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -97,9 +97,9 @@ # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main', - 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', - 'tflite_convert = tensorflow.contrib.lite.python.tflite_convert:main', - 'toco = tensorflow.contrib.lite.python.tflite_convert:main', + 'toco_from_protos = tensorflow.lite.toco.python.toco_from_protos:main', + 'tflite_convert = tensorflow.lite.python.tflite_convert:main', + 'toco = tensorflow.lite.python.tflite_convert:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', # We need to keep the TensorBoard command, even though the console script # is now declared by the tensorboard pip package. If we remove the diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 44885d2abbf722..d3a3204b234bfe 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -27,6 +27,7 @@ load("//third_party/icu:workspace.bzl", icu = "repo") load("//third_party/jpeg:workspace.bzl", jpeg = "repo") load("//third_party/nasm:workspace.bzl", nasm = "repo") load("//third_party/kissfft:workspace.bzl", kissfft = "repo") +load("//third_party/keras_applications_archive:workspace.bzl", keras_applications = "repo") def initialize_third_party(): """ Load third party repositories. See above load() statements. """ @@ -34,6 +35,7 @@ def initialize_third_party(): flatbuffers() highwayhash() icu() + keras_applications() kissfft() jpeg() nasm() @@ -121,11 +123,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "com_google_absl", build_file = clean_dep("//third_party:com_google_absl.BUILD"), - sha256 = "1dd634982ef56c47b6f425f74c906dc28ff10cf060bb991cd614365eb2ad98d4", - strip_prefix = "abseil-cpp-f86f9413856b65afdd61fea938d684b8ab73115a", + sha256 = "3cf6132129ba87f0781c383bfaf381b7174b5818e81fffcc5d04bb451154f0f2", + strip_prefix = "abseil-cpp-f95179062eb65ce40895cc76f1398cce25394369", urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/f86f9413856b65afdd61fea938d684b8ab73115a.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/f86f9413856b65afdd61fea938d684b8ab73115a.tar.gz", + "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/f95179062eb65ce40895cc76f1398cce25394369.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/f95179062eb65ce40895cc76f1398cce25394369.tar.gz", ], ) @@ -470,11 +472,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "llvm", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), - sha256 = "45fa15cb3fa36535b8e8b6eb0b54e146d3fba3b77924151e4c0827414f409563", - strip_prefix = "llvm-d362f0fbbba60765a35260d349608f382ffaa0ed", + sha256 = "2342cb98083eb1191a8411542dcd57cb3efc28677be4412e166f40cf22bd2b8c", + strip_prefix = "llvm-3fe1b12fca949399a3334a072ee7f96e2b6f557e", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d362f0fbbba60765a35260d349608f382ffaa0ed.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/d362f0fbbba60765a35260d349608f382ffaa0ed.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/3fe1b12fca949399a3334a072ee7f96e2b6f557e.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/3fe1b12fca949399a3334a072ee7f96e2b6f557e.tar.gz", ], ) diff --git a/third_party/keras_applications_archive/BUILD b/third_party/keras_applications_archive/BUILD new file mode 100644 index 00000000000000..82bab3ffd96463 --- /dev/null +++ b/third_party/keras_applications_archive/BUILD @@ -0,0 +1 @@ +# This empty BUILD file is required to make Bazel treat this directory as a package. diff --git a/third_party/keras_applications_archive/BUILD.bazel b/third_party/keras_applications_archive/BUILD.bazel new file mode 100644 index 00000000000000..57c8f597c7f64c --- /dev/null +++ b/third_party/keras_applications_archive/BUILD.bazel @@ -0,0 +1,31 @@ +# Description: Keras Applications: set of pre-trained deep learning models. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # MIT + +exports_files(["LICENSE"]) + +py_library( + name = "keras_applications", + srcs = [ + "keras_applications/__init__.py", + "keras_applications/densenet.py", + "keras_applications/imagenet_utils.py", + "keras_applications/inception_resnet_v2.py", + "keras_applications/inception_v3.py", + "keras_applications/mobilenet.py", + "keras_applications/mobilenet_v2.py", + "keras_applications/nasnet.py", + "keras_applications/resnet50.py", + "keras_applications/vgg16.py", + "keras_applications/vgg19.py", + "keras_applications/xception.py", + ], + deps = [ + "@org_tensorflow//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/third_party/keras_applications_archive/workspace.bzl b/third_party/keras_applications_archive/workspace.bzl new file mode 100644 index 00000000000000..e90630fa974fb9 --- /dev/null +++ b/third_party/keras_applications_archive/workspace.bzl @@ -0,0 +1,15 @@ +"""Loads Keras-applications python package.""" + +load("//third_party:repo.bzl", "third_party_http_archive") + +def repo(): + third_party_http_archive( + name = "keras_applications_archive", + strip_prefix = "keras-applications-1.0.6", + sha256 = "2cb412c97153160ec267b238e958d281ac3532b139cab42045c2d7086a157c21", + urls = [ + "http://mirror.bazel.build/github.com/keras-team/keras-applications/archive/1.0.6.tar.gz", + "https://github.com/keras-team/keras-applications/archive/1.0.6.tar.gz", + ], + build_file = "//third_party/keras_applications_archive:BUILD.bazel", + ) diff --git a/third_party/nccl/archive.BUILD b/third_party/nccl/archive.BUILD index f57f04c75ed64f..c0833828a736a1 100644 --- a/third_party/nccl/archive.BUILD +++ b/third_party/nccl/archive.BUILD @@ -7,10 +7,10 @@ exports_files(["LICENSE.txt"]) load( "@local_config_nccl//:build_defs.bzl", - "device_link", "gen_nccl_h", "nccl_library", "rdc_copts", + "rdc_library", ) load( "@local_config_cuda//cuda:build_defs.bzl", @@ -136,9 +136,9 @@ nccl_library( linkstatic = True, ) -device_link( +rdc_library( name = "device_code", - srcs = [ + deps = [ ":functions", ":max", ":min", @@ -167,13 +167,8 @@ nccl_library( copts = cuda_default_copts(), deps = [ ":device_code", - ":functions", ":include_hdrs", - ":max", - ":min", - ":prod", ":src_hdrs", - ":sum", ], visibility = ["//visibility:public"], ) diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl index bb6518753be032..42de79c411c844 100644 --- a/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/nccl/build_defs.bzl.tpl @@ -43,8 +43,7 @@ def _process_srcs_impl(ctx): substitutions = { "\"collectives.h": "\"collectives/collectives.h", "\"../collectives.h": "\"collectives/collectives.h", - "#if __CUDACC_VER_MAJOR__": - "#if defined __CUDACC_VER_MAJOR__ && __CUDACC_VER_MAJOR__", + "#if __CUDACC_VER_MAJOR__": "#if defined __CUDACC_VER_MAJOR__ && __CUDACC_VER_MAJOR__", # Substitutions are applied in order. "std::nullptr_t": "nullptr_t", "nullptr_t": "std::nullptr_t", @@ -140,13 +139,16 @@ _gen_link_src = rule( ) """Patches the include directives for the link.stub file.""" -def device_link(name, srcs): - """Links seperately compiled relocatable device code into a cc_library.""" +def rdc_library(name, deps): + """Produces a cc_library from deps containing relocatable device code.""" - # From .a and .pic.a archives, just use the latter. + # From .a and .pic.a archives, just use the latter. Otherwise we get + # multiply defined symbols. + # TODO(csigg): C++ Sandwich once available should allow passing this target + # to a cc_library dependency, which would avoid the linking order issue. _filter( - name = name + "_pic_a", - srcs = srcs, + name = name + "_deps_a", + srcs = deps, suffix = ".pic.a", ) @@ -160,10 +162,8 @@ def device_link(name, srcs): cmd = ("$(location %s) " % nvlink + select({ # NCCL is only supported on Linux. - "@org_tensorflow//tensorflow:linux_x86_64": - "--cpu-arch=X86_64 ", - "@org_tensorflow//tensorflow:linux_ppc64le": - "--cpu-arch=PPC64LE ", + "@org_tensorflow//tensorflow:linux_x86_64": "--cpu-arch=X86_64 ", + "@org_tensorflow//tensorflow:linux_ppc64le": "--cpu-arch=PPC64LE ", "//conditions:default": "", }) + "--arch=%s $(SRCS) " % arch + @@ -172,7 +172,7 @@ def device_link(name, srcs): native.genrule( name = "%s_%s" % (name, arch), outs = [register_hdr, cubin], - srcs = [name + "_pic_a"], + srcs = [name + "_deps_a"], cmd = cmd, tools = [nvlink], ) @@ -182,8 +182,9 @@ def device_link(name, srcs): # Generate fatbin header from all cubins. fatbin_hdr = name + ".fatbin.h" fatbinary = "@local_config_nccl//:cuda/bin/fatbinary" - cmd = ("PATH=$$CUDA_TOOLKIT_PATH/bin:$$PATH " + # for bin2c - "$(location %s) -64 --cmdline=--compile-only --link " % fatbinary + + bin2c = "@local_config_nccl//:cuda/bin/bin2c" + cmd = ("$(location %s) -64 --cmdline=--compile-only " % fatbinary + + "--link --bin2c-path $$(dirname $(location %s)) " % bin2c + "--compress-all %s --create=%%{name}.fatbin " % " ".join(images) + "--embedded-fatbin=$@") native.genrule( @@ -191,12 +192,12 @@ def device_link(name, srcs): outs = [fatbin_hdr], srcs = cubins, cmd = cmd, - tools = [fatbinary], + tools = [fatbinary, bin2c], ) # Generate the source file #including the headers generated above. _gen_link_src( - name = name + "_cc", + name = name + "_dlink_src", # Include just the last one, they are equivalent. register_hdr = register_hdr, fatbin_hdr = fatbin_hdr, @@ -206,12 +207,13 @@ def device_link(name, srcs): # Compile the source file into the cc_library. native.cc_library( - name = name, - srcs = [name + "_cc"], + name = name + "_dlink_a", + srcs = [ + name + "_dlink_src", + ], textual_hdrs = [register_hdr, fatbin_hdr], deps = [ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudart_static", ], defines = [ # Silence warning about including internal header. @@ -220,4 +222,31 @@ def device_link(name, srcs): "__NV_EXTRA_INITIALIZATION=", "__NV_EXTRA_FINALIZATION=", ], + linkstatic = True, + ) + + # Repackage deps into a single archive. This avoid unresolved symbols when + # the archives happen to be linked in the wrong order. For more details, see + # https://eli.thegreenplace.net/2013/07/09/library-order-in-static-linking + native.genrule( + name = name + "_a", + srcs = [ + name + "_deps_a", + name + "_dlink_a", + ], + outs = [name + ".a"], + # See https://stackoverflow.com/a/23621751 + cmd = """ +addlibs=$$(echo $(SRCS) | sed "s/[^ ]* */\\naddlib &/g") +printf "create $@$${addlibs}\\nsave\\nend" | $(AR) -M +""", + ) + + native.cc_library( + name = name, + srcs = [name + "_a"], + deps = [ + "@local_config_cuda//cuda:cudart_static", + ], + linkstatic = True, ) diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl index 7f00df096202da..1e6422b49ef4d7 100644 --- a/third_party/nccl/nccl_configure.bzl +++ b/third_party/nccl/nccl_configure.bzl @@ -3,7 +3,7 @@ `nccl_configure` depends on the following environment variables: - * `TF_NCCL_VERSION`: The NCCL version. + * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source. * `NCCL_INSTALL_PATH`: The installation path of the NCCL library. * `NCCL_HDR_PATH`: The installation path of the NCCL header files. """ @@ -44,6 +44,7 @@ _NCCL_ARCHIVE_BUILD_CONTENT = """ exports_files([ "cuda/bin/crt/link.stub", "cuda/bin/fatbinary", + "cuda/bin/bin2c", "nvlink", ]) diff --git a/third_party/toolchains/gpus/cuda/BUILD b/third_party/toolchains/gpus/cuda/BUILD index f59e025019caff..f63a0ea8192578 100644 --- a/third_party/toolchains/gpus/cuda/BUILD +++ b/third_party/toolchains/gpus/cuda/BUILD @@ -1258,7 +1258,7 @@ genrule( "cuda/lib/libcupti.so.9.0", ], cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0" +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0" """, ) diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD index 05abcb56d84789..247e0ace243264 100755 --- a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD +++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD @@ -1253,7 +1253,7 @@ genrule( "cuda/lib/libcupti.so.9.0", ], cmd = """ -if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0" +if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0" """, )