diff --git a/.bazelrc b/.bazelrc index 483eace0cfb..a3051b59f47 100644 --- a/.bazelrc +++ b/.bazelrc @@ -41,6 +41,9 @@ build:nativeopt --copt=-march=native build:nativeopt --host_copt=-march=native build:nativeopt --copt=-O3 +# Support TF pluggable device +build --copt=-DSUPPORT_TF_PLUGINS --define=with_plugins_support=true + build --keep_going build --verbose_failures=true build --spawn_strategy=standalone diff --git a/tensorflow_serving/BUILD b/tensorflow_serving/BUILD index 7a017679afb..270d594ee18 100644 --- a/tensorflow_serving/BUILD +++ b/tensorflow_serving/BUILD @@ -24,3 +24,10 @@ filegroup( ], ), ) + + +config_setting( + name = "with_plugins_support", + define_values = {"with_plugins_support": "true"}, + visibility = ["//visibility:public"], +) diff --git a/tensorflow_serving/model_servers/BUILD b/tensorflow_serving/model_servers/BUILD index d11434eada6..48ccb571b2a 100644 --- a/tensorflow_serving/model_servers/BUILD +++ b/tensorflow_serving/model_servers/BUILD @@ -4,6 +4,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar") # Placeholder: load py_binary # Placeholder: load py_test +load("//tensorflow_serving:serving.bzl", "if_with_plugins_support") load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_google", "if_libtpu", "if_with_tpu_support") load("//tensorflow_serving:tensorflow_version.bzl", "if_not_v2", "if_v2") @@ -423,7 +424,11 @@ cc_library( "@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:tensorflow", "@org_tensorflow//tensorflow/core/profiler/rpc:profiler_service_impl", - ] + SUPPORTED_TENSORFLOW_OPS, + ] + SUPPORTED_TENSORFLOW_OPS + if_with_plugins_support([ + "@org_tensorflow//tensorflow/c:c_api_experimental", + "@org_tensorflow//tensorflow/c:kernels_experimental", + "@org_tensorflow//tensorflow/c/experimental/next_pluggable_device:c_api", + ]), ) cc_library( @@ -441,7 +446,6 @@ cc_library( ], deps = [ ":server_lib", - "@org_tensorflow//tensorflow/c:c_api", "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit", "@org_tensorflow//tensorflow/core:lib", "@org_tensorflow//tensorflow/core/platform/cloud:gcs_file_system", @@ -458,10 +462,13 @@ cc_library( cc_binary( name = "tensorflow_model_server", - linkopts = [ - # Exports Tensorflow APIs - "-rdynamic", - ], + additional_linker_inputs = + if_with_plugins_support([ + "tf_c_api_exported_symbols.lds", + ]), + linkopts = if_with_plugins_support([ + "-Wl,-dynamic-list,$(location :tf_c_api_exported_symbols.lds)", + ]), stamp = 1, visibility = [ ":testing", diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc index 180db4bd77a..11f4d9f74f4 100644 --- a/tensorflow_serving/model_servers/main.cc +++ b/tensorflow_serving/model_servers/main.cc @@ -79,7 +79,12 @@ void InitializeTPU(tensorflow::serving::main::Server::Options& server_options) { } #endif -int main(int argc, char** argv) { +#ifdef SUPPORT_TF_PLUGINS +#include +#include "tensorflow/c/c_api_experimental.h" +#endif + +int main(int argc, char **argv) { tensorflow::serving::main::Server::Options options; bool display_version = false; bool xla_cpu_compilation_enabled = false; @@ -296,6 +301,12 @@ int main(int argc, char** argv) { &options.thread_pool_factory_config_file, "If non-empty, read an ascii ThreadPoolConfig protobuf " "from the supplied file name."), +#ifdef SUPPORT_TF_PLUGINS + tensorflow::Flag("tensorflow_plugins", &options.tensorflow_plugins, + "Enable tensorflow plugins by giving a path to folder. " + "If non-empty, load all .so files under this folder " + "as tensorflow plugins."), +#endif tensorflow::Flag("skip_initialize_tpu", &options.skip_initialize_tpu, "Whether to skip auto initializing TPU.")}; @@ -306,6 +317,28 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); + +#ifdef SUPPORT_TF_PLUGINS + if (std::filesystem::exists(options.tensorflow_plugins)) { + for (const auto &entry : + std::filesystem::directory_iterator(options.tensorflow_plugins)) { + std::string plugin_file = entry.path().string(); + if (plugin_file.size() > 3 && + plugin_file.compare(plugin_file.size() - 3, 3, ".so") == 0) { + TF_Status *plugin_status = TF_NewStatus(); + TF_LoadPluggableDeviceLibrary(entry.path().c_str(), plugin_status); + TF_Code code = TF_GetCode(plugin_status); + if (code == TF_OK) { + VLOG(0) << "plugin library " << entry.path() << " load successfully!"; + } else { + std::string status_msg(TF_Message(plugin_status)); + VLOG(0) << "Could not load " << entry.path() << ": " << status_msg; + } + } + } + } +#endif + #if defined(LIBTPU_ON_GCE) || defined(PLATFORM_CLOUD_TPU) InitializeTPU(options); #endif diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h index 178672a70d4..3930ef49394 100644 --- a/tensorflow_serving/model_servers/server.h +++ b/tensorflow_serving/model_servers/server.h @@ -101,6 +101,9 @@ class Server { tensorflow::string thread_pool_factory_config_file; bool enable_signature_method_name_check = false; bool enable_profiler = true; +#ifdef SUPPORT_TF_PLUGINS + string tensorflow_plugins = ""; +#endif bool skip_initialize_tpu = false; Options(); }; diff --git a/tensorflow_serving/model_servers/tf_c_api_exported_symbols.lds b/tensorflow_serving/model_servers/tf_c_api_exported_symbols.lds new file mode 100644 index 00000000000..b5e82a097c2 --- /dev/null +++ b/tensorflow_serving/model_servers/tf_c_api_exported_symbols.lds @@ -0,0 +1,3 @@ +{ + *TF_*; +}; diff --git a/tensorflow_serving/serving.bzl b/tensorflow_serving/serving.bzl index 1f8788eb88c..cfff390368f 100644 --- a/tensorflow_serving/serving.bzl +++ b/tensorflow_serving/serving.bzl @@ -78,6 +78,13 @@ def serving_tensorflow_proto_dep(dep): """ return "{}_cc".format(dep) +def if_with_plugins_support(if_true, if_false = []): + """Shorthand for select()ing whether to build API support for TensorFlow Plugins""" + return select({ + "//tensorflow_serving:with_plugins_support": if_true, + "//conditions:default": if_false, + }) + def oss_only_cc_test(name, srcs = [], deps = [], data = [], size = "medium", linkstatic = 0): """cc_test that is only run in open source environment.""" return native.cc_test(