Skip to content

Commit

Permalink
[RUNTIME] Implement TVMDSOOp(TensorFlow custom op) for TVM runtime (a…
Browse files Browse the repository at this point in the history
…pache#4459)

* Add implementation of TVMDSOOp

* feat: Update cmake script to work with c++11 and in-repo build

* feat: Use libtvm as oplib dependency

* fix: Add missing link dependency to libtvm

* feat: Update tf tvmdso op by review comments

* fix: Update with pr comments

* fix: Fix lint

* feat: Add test script and fix gpu shape

* feat: Add test script and fix gpu shape

* fix: Conditional build tftvm op for gpu

* fix: Conditional build tftvm op for gpu

* fix: Fix pylint of tf_op module.py

* fix: Fix pylint of tf_op module.py

* feat: Conditional enable gpu test for tftvm op

* feat: Conditional enable gpu test for tftvm op

* feat: Add tf_tvmdsoop test script as an app test

* fix: Fix gpu/cpu enabled check on tvm in test script

* fix: Make tf tvmdso op test script runnable with pytest

* remove unused test script test_tfop_module.py

* fix: Remove pushd & popd in tfdsoop test script

* fix: Upgrade tftvmop use python3 to find TensorFlow

* fix: Upgrade tftvmop use python3 to find TensorFlow

* fix: Change target_link_options to target_link_libraries

* fix: Add tftvmop build script's c++ option

* fix: Add tvm library path to tf op test library path

* fix: Debug ci build for tftvm dso op

* fix: Fix cmake error and skip tfop test

* fix: Fix typo and indentation issues

* feat: Use TF list input op def

* fix: Fix style and unexpected changes

Co-authored-by: baoxinqi <[email protected]>
Co-authored-by: Chen Dihao <[email protected]>
Co-authored-by: wrongtest <[email protected]>
  • Loading branch information
4 people authored and dpankratz committed Apr 24, 2020
1 parent 04a49cb commit 26c3b54
Show file tree
Hide file tree
Showing 11 changed files with 728 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
tvm_option(USE_MICRO "Build with Micro" OFF)
tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)

# 3rdparty libraries
tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
Expand Down Expand Up @@ -259,6 +260,7 @@ include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TF_TVMDSOOP.cmake)

if(NOT MSVC)
include(CheckCXXCompilerFlag)
Expand Down
34 changes: 34 additions & 0 deletions apps/tf_tvmdsoop/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tf_tvmdsoop C CXX)

set(TFTVM_COMPILE_FLAGS -std=c++11)
set(BUILD_TVMDSOOP_ONLY ON)
set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)

include_directories(${TVM_ROOT}/3rdparty/dlpack/include/)
include_directories(${TVM_ROOT}/3rdparty/dmlc-core/include/)
include_directories(${TVM_ROOT}/include)

link_directories(${TVM_ROOT}/build)

include(${TVM_ROOT}/cmake/util/FindCUDA.cmake)
include(${TVM_ROOT}/cmake/modules/CUDA.cmake)

include(${TVM_ROOT}/cmake/modules/contrib/TF_TVMDSOOP.cmake)
35 changes: 35 additions & 0 deletions apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

TVM_ROOT=$(cd $(dirname $0)/../..; pwd)
echo "TVM_ROOT=${TVM_ROOT}"

export PYTHONPATH=${TVM_ROOT}/python

python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1
if [ "$?" -eq 0 ]; then
echo "Build TF_TVMDSOOP with gpu support and execute tests"
CMAKE_OPTIONS="-DUSE_CUDA=ON -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}"

mkdir -p build
cd build; cmake .. ${CMAKE_OPTIONS} && make
cd ..

LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests
fi

118 changes: 118 additions & 0 deletions apps/tf_tvmdsoop/tests/test_tfop_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for tf op module"""
import tempfile
import os
import logging
import tensorflow as tf
import numpy as np
import tvm
from tvm import te
from tvm.contrib import tf_op


def test_use_tvmdso_op():
"""main test function"""

def export_cpu_add_lib():
"""create cpu add op lib"""
n = te.var("n")
ph_a = te.placeholder((n,), name='ph_a')
ph_b = te.placeholder((n,), name='ph_b')
ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
sched = te.create_schedule(ph_c.op)
fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")
lib_path = tempfile.mktemp("tvm_add_dll.so")
fadd_dylib.export_library(lib_path)
return lib_path


def export_gpu_add_lib():
"""create gpu add op lib"""
n = te.var("n")
ph_a = te.placeholder((n,), name='ph_a')
ph_b = te.placeholder((n,), name='ph_b')
ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
sched = te.create_schedule(ph_c.op)
b_axis, t_axis = sched[ph_c].split(ph_c.op.axis[0], factor=64)
sched[ph_c].bind(b_axis, te.thread_axis("blockIdx.x"))
sched[ph_c].bind(t_axis, te.thread_axis("threadIdx.x"))
fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "cuda", name="vector_add")
lib_path = tempfile.mktemp("tvm_add_cuda_dll.so")
fadd_dylib.export_library(lib_path)
return lib_path


def test_add(session, lib_path, tf_device):
"""test add lib with TensorFlow wrapper"""
module = tf_op.OpModule(lib_path)

left = tf.placeholder("float32", shape=[4])
right = tf.placeholder("float32", shape=[4])

feed_dict = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
expect = np.asarray([6.0, 8.0, 10.0, 12.0])

add1 = module.func("vector_add", output_shape=[4], output_dtype="float")
add2 = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")
add3 = module.func("vector_add", output_shape=[tf.shape(left)[0]], output_dtype="float")

with tf.device(tf_device):
output1 = session.run(add1(left, right), feed_dict)
np.testing.assert_equal(output1, expect)

output2 = session.run(add2(left, right), feed_dict)
np.testing.assert_equal(output2, expect)

output3 = session.run(add3(left, right), feed_dict)
np.testing.assert_equal(output3, expect)


def cpu_test(session):
"""test function for cpu"""
cpu_lib = None
try:
cpu_lib = export_cpu_add_lib()
test_add(session, cpu_lib, "/cpu:0")
finally:
if cpu_lib is not None:
os.remove(cpu_lib)


def gpu_test(session):
"""test function for gpu"""
gpu_lib = None
try:
gpu_lib = export_gpu_add_lib()
test_add(session, gpu_lib, "/gpu:0")
finally:
if gpu_lib is not None:
os.remove(gpu_lib)

with tf.Session() as session:
if tvm.runtime.enabled("cpu"):
logging.info("Test TensorFlow op on cpu kernel")
cpu_test(session)
if tvm.runtime.enabled("gpu"):
logging.info("Test TensorFlow op on gpu kernel")
gpu_test(session)


if __name__ == "__main__":
test_use_tvmdso_op()
4 changes: 4 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,7 @@ set(USE_EXAMPLE_EXT_RUNTIME OFF)

# Whether use Thrust
set(USE_THRUST OFF)

# Whether to build the TensorFlow TVMDSOOp module
set(USE_TF_TVMDSOOP OFF)

58 changes: 58 additions & 0 deletions cmake/modules/contrib/TF_TVMDSOOP.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(NOT USE_TF_TVMDSOOP STREQUAL "OFF")
find_package(Python3 COMPONENTS Interpreter)

execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow compile flags")
endif()

if(NOT USE_CUDA STREQUAL "OFF")
add_definitions(-DTF_TVMDSOOP_ENABLE_GPU)
endif()

execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
OUTPUT_VARIABLE TF_LINK_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow link flags")
endif()

string(REGEX REPLACE "\n" " " TF_FLAGS "${TF_COMPILE_FLAGS} ${TF_LINK_FLAGS}")
separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})


set(OP_LIBRARY_NAME tvm_dso_op)
file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc)
add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
set(TFTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})

if (NOT BUILD_TVMDSOOP_ONLY STREQUAL "ON")
add_dependencies(${OP_LIBRARY_NAME} tvm)
endif()

target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
target_link_libraries(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})

endif()

20 changes: 20 additions & 0 deletions python/tvm/contrib/tf_op/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module container of TensorFlow TVMDSO op"""
from . import module

OpModule = module.OpModule
Loading

0 comments on commit 26c3b54

Please sign in to comment.