Skip to content

Commit

Permalink
Support real custom ops for Toco --allow_eager_ops flow.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 217020295
  • Loading branch information
miaout17 authored and tensorflower-gardener committed Oct 14, 2018
1 parent 109a0c1 commit e4b1832
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 1 deletion.
5 changes: 5 additions & 0 deletions tensorflow/contrib/lite/toco/tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_library(
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/memory",
Expand All @@ -42,6 +43,7 @@ tf_cc_test(
deps = [
":operator",
"//tensorflow/contrib/lite/toco:tooling_util",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
Expand Down Expand Up @@ -71,6 +73,7 @@ tf_cc_test(
tags = ["no_oss"],
deps = [
":types",
"//tensorflow/core:ops",
"@com_google_googletest//:gtest_main",
],
)
Expand Down Expand Up @@ -106,6 +109,7 @@ tf_cc_test(
deps = [
":export",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/core:ops",
"@com_google_googletest//:gtest_main",
],
)
Expand Down Expand Up @@ -141,6 +145,7 @@ tf_cc_test(
":import",
"//tensorflow/contrib/lite:schema_fbs_version",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/core:ops",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
],
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/lite/toco/tflite/export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ OperatorKey GetOperatorKey(

// TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
// to populate a regular custom op. We need to find a way to fix this.
if (allow_flex_ops) {
if (ShouldExportAsFlexOp(allow_flex_ops, unsupported_op.tensorflow_op)) {
key.is_flex_op = true;
key.flex_tensorflow_op = tensorflow_op;
key.custom_code =
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/contrib/lite/toco/tflite/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"

namespace toco {
Expand Down Expand Up @@ -1258,6 +1260,16 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}

if (ShouldExportAsFlexOp(allow_flex_ops_, node_def.op())) {
fbb->Vector([&]() {
fbb->String(node_def.op());
fbb->String(op.tensorflow_node_def);
});
fbb->Finish();
LOG(INFO) << "Writing flex op: " << node_def.op();
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}

bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
Expand Down Expand Up @@ -1588,6 +1600,21 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
return result;
}

bool ShouldExportAsFlexOp(bool allow_flex_ops,
const string& tensorflow_op_name) {
// If Flex ops aren't allow at all, simply return false.
if (!allow_flex_ops) {
return false;
}
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
// it, export the op as an Flex op. Otherwise, export it as a regular custom
// op.
const tensorflow::OpDef* op_def = nullptr;
return tensorflow::OpRegistry::Global()
->LookUpOpDef(tensorflow_op_name, &op_def)
.ok();
}

} // namespace tflite

} // namespace toco
5 changes: 5 additions & 0 deletions tensorflow/contrib/lite/toco/tflite/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ class BaseOperator {
OperatorType type_;
};

// Helper function to determine if a unsupported TensorFlow op should be
// exported as an Flex op or a regular custom op.
bool ShouldExportAsFlexOp(bool allow_flex_ops,
const string& tensorflow_op_name);

} // namespace tflite

} // namespace toco
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/contrib/lite/toco/tflite/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,12 @@ TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
EXPECT_TRUE(output_node_def.attr().empty());
}

TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));
EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));
}

} // namespace
} // namespace tflite

Expand Down

0 comments on commit e4b1832

Please sign in to comment.