Skip to content

Commit

Permalink
[Rust][Fix] Memory leak (apache#8714)
Browse files Browse the repository at this point in the history
* Fix obvious memory leak in function.rs

* Update object pointer
  • Loading branch information
jroesch authored and ylc committed Sep 29, 2021
1 parent 5034dd6 commit 05c9dfe
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 69 deletions.
89 changes: 39 additions & 50 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
//! See the tests and examples repository for more examples.
use std::convert::{TryFrom, TryInto};
use std::sync::Arc;
use std::{
ffi::CString,
os::raw::{c_char, c_int},
Expand All @@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};

pub type Result<T> = std::result::Result<T, Error>;

/// Wrapper around TVM function handle which includes `is_global`
/// indicating whether the function is global or not, and `is_cloned` showing
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
pub struct Function {
pub(crate) handle: ffi::TVMFunctionHandle,
// whether the registered function is global or not.
is_global: bool,
from_rust: bool,
struct FunctionPtr {
handle: ffi::TVMFunctionHandle,
}

unsafe impl Send for Function {}
unsafe impl Sync for Function {}
// NB(@jroesch): I think this is ok, need to double check,
// if not we should mutex the pointer or move to Rc.
unsafe impl Send for FunctionPtr {}
unsafe impl Sync for FunctionPtr {}

impl FunctionPtr {
fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
FunctionPtr { handle }
}
}

impl Drop for FunctionPtr {
fn drop(&mut self) {
check_call!(ffi::TVMFuncFree(self.handle));
}
}

/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust.
#[derive(Debug, Hash)]
pub struct Function {
inner: Arc<FunctionPtr>,
}

impl Function {
pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
Function {
handle,
is_global: false,
from_rust: false,
inner: Arc::new(FunctionPtr::from_raw(handle)),
}
}

pub unsafe fn null() -> Self {
Function {
handle: std::ptr::null_mut(),
is_global: false,
from_rust: false,
}
Function::from_raw(std::ptr::null_mut())
}

/// For a given function, it returns a function by name.
Expand All @@ -84,11 +92,7 @@ impl Function {
if handle.is_null() {
None
} else {
Some(Function {
handle,
is_global: true,
from_rust: false,
})
Some(Function::from_raw(handle))
}
}

Expand All @@ -103,12 +107,7 @@ impl Function {

/// Returns the underlying TVM function handle.
pub fn handle(&self) -> ffi::TVMFunctionHandle {
self.handle
}

/// Returns `true` if the underlying TVM function is global and `false` otherwise.
pub fn is_global(&self) -> bool {
self.is_global
self.inner.handle
}

/// Calls the function that created from `Builder`.
Expand All @@ -122,7 +121,7 @@ impl Function {

let ret_code = unsafe {
ffi::TVMFuncCall(
self.handle,
self.handle(),
values.as_mut_ptr() as *mut ffi::TVMValue,
type_codes.as_mut_ptr() as *mut c_int,
num_args as c_int,
Expand Down Expand Up @@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,);

impl Clone for Function {
fn clone(&self) -> Function {
Self {
handle: self.handle,
is_global: self.is_global,
from_rust: true,
Function {
inner: self.inner.clone(),
}
}
}

// impl Drop for Function {
// fn drop(&mut self) {
// if !self.is_global && !self.is_cloned {
// check_call!(ffi::TVMFuncFree(self.handle));
// }
// }
// }

impl From<Function> for RetValue {
fn from(func: Function) -> RetValue {
RetValue::FuncHandle(func.handle)
RetValue::FuncHandle(func.handle())
}
}

Expand All @@ -198,7 +187,7 @@ impl TryFrom<RetValue> for Function {

fn try_from(ret_value: RetValue) -> Result<Function> {
match ret_value {
RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
_ => Err(Error::downcast(
format!("{:?}", ret_value),
"FunctionHandle",
Expand All @@ -209,10 +198,10 @@ impl TryFrom<RetValue> for Function {

impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
if func.handle.is_null() {
if func.handle().is_null() {
ArgValue::Null
} else {
ArgValue::FuncHandle(func.handle)
ArgValue::FuncHandle(func.handle())
}
}
}
Expand All @@ -222,7 +211,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {

fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
Expand All @@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {

fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl Module {
return Err(errors::Error::NullHandle(name.into_string()?.to_string()));
}

Ok(Function::new(fhandle))
Ok(Function::from_raw(fhandle))
}

/// Imports a dependent module such as `.ptx` for cuda gpu.
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ use num_traits::Num;

use crate::errors::NDArrayError;

use crate::object::{Object, ObjectPtr};
use crate::object::{Object, ObjectPtr, ObjectRef};

/// See the [`module-level documentation`](../ndarray/index.html) for more details.
#[repr(C)]
Expand All @@ -73,7 +73,7 @@ pub struct NDArrayContainer {
// Container Base
dl_tensor: DLTensor,
manager_ctx: *mut c_void,
// TOOD: shape?
shape: ObjectRef,
}

impl NDArrayContainer {
Expand Down
12 changes: 0 additions & 12 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,6 @@ impl Object {
}
}

// impl fmt::Debug for Object {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// let index =
// format!("{} // key: {}", self.type_index, "the_key");

// f.debug_struct("Object")
// .field("type_index", &index)
// // TODO(@jroesch: do we expose other fields?)
// .finish()
// }
// }

/// An unsafe trait which should be implemented for an object
/// subtype.
///
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/to_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub trait ToFunction<I, O>: Sized {
&mut fhandle as *mut ffi::TVMFunctionHandle,
));

Function::new(fhandle)
Function::from_raw(fhandle)
}

/// The callback function which is wrapped converted by TVM
Expand Down
10 changes: 8 additions & 2 deletions rust/tvm-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

extern crate bindgen;

use std::{path::{Path, PathBuf}, str::FromStr};
use std::{
path::{Path, PathBuf},
str::FromStr,
};

use anyhow::{Context, Result};
use tvm_build::{BuildConfig, CMakeSetting};
Expand Down Expand Up @@ -195,7 +198,10 @@ fn find_using_tvm_build() -> Result<TVMInstall> {
if cfg!(feature = "use-vitis-ai") {
build_config.settings.use_vitis_ai = Some(true);
}
if cfg!(any(feature = "static-linking", feature = "build-static-runtime")) {
if cfg!(any(
feature = "static-linking",
feature = "build-static-runtime"
)) {
build_config.settings.build_static_runtime = Some(true);
}

Expand Down
2 changes: 1 addition & 1 deletion rust/tvm/tests/basics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn main() {
let mut arr = NDArray::empty(shape, dev, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let ret = NDArray::empty(shape, dev, dtype);
let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
let fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
if !fadd.enabled(dev_name) {
return;
}
Expand Down

0 comments on commit 05c9dfe

Please sign in to comment.