Rust Refactor Stage 4: Rewrite Rust graph runtime to use new APIs #5830
Conversation
@robo-corg @tqchen this is like a straight copy of the old
strides: if dlt.strides.is_null() {
    None
} else {
    Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
Suggested change:
- Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, size).to_vec())
+ Some(slice::from_raw_parts_mut(dlt.strides as *mut usize, shape.len()).to_vec())
I don't think this is correct, as the field is currently passing the number of elements, not the rank of the tensor.
From this figure (https://i.stack.imgur.com/oQQVI.png), I think the length of strides should be equal to the rank of the tensor (i.e. the number of dimensions).
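To illustrate why the stride count tracks the rank rather than the element count, here is a minimal, hedged sketch (not code from this PR) that computes contiguous row-major strides from a shape; one stride is produced per dimension, so the vector's length equals the rank even when the tensor holds far more elements.

```rust
/// Compute contiguous (row-major) strides, in elements, for a given shape.
/// One stride per dimension, so `strides.len() == shape.len()` (the rank).
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk dimensions from innermost to outermost, accumulating the
    // product of the dimensions to the right of each axis.
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

fn main() {
    let shape = [2usize, 3, 4]; // rank-3 tensor with 24 elements
    let strides = row_major_strides(&shape);
    // Length 3 (the rank), not 24 (the element count).
    assert_eq!(strides, vec![12, 4, 1]);
    println!("{:?}", strides);
}
```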
@jroesch Can we enable derive_default for bindgen? Otherwise, test_wasm32 fails with tvm-sys because of the generated padding field. The following change, which existed in the previous implementation, looks necessary.

diff --git a/rust/tvm-graph-rt/src/array.rs b/rust/tvm-graph-rt/src/array.rs
index 38519bd..8209b59 100644
--- a/rust/tvm-graph-rt/src/array.rs
+++ b/rust/tvm-graph-rt/src/array.rs
@@ -288,6 +288,7 @@ impl<'a> Tensor<'a> {
self.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
+ ..Default::default()
}
}
}
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 85e16be..01d2934 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -54,6 +54,7 @@ fn main() {
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
+ .derive_default(true)
.generate()
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs
index 1627e9e..5d09d86 100644
--- a/rust/tvm-sys/src/array.rs
+++ b/rust/tvm-sys/src/array.rs
@@ -48,6 +48,7 @@ macro_rules! impl_dltensor_from_ndarray {
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const i64 as *mut i64,
byte_offset: 0,
+ ..Default::default()
}
}
}
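For readers unfamiliar with the `..Default::default()` lines in the patch above: Rust's struct-update syntax fills every field not named explicitly from another value, here the type's `Default` impl, which is why bindgen must derive `Default` for the generated struct (padding fields included). A minimal sketch with a hypothetical stand-in struct:

```rust
// Hypothetical stand-in for a bindgen-generated struct with a padding
// field; the real type in the PR is the generated DLTensor.
#[derive(Default, Debug, PartialEq)]
struct FakeDLTensor {
    byte_offset: u64,
    _padding: [u8; 4], // fields like this are why derive_default(true) is needed
}

fn main() {
    // Name only the fields you care about; `Default` (which must exist)
    // supplies the rest, padding included.
    let t = FakeDLTensor {
        byte_offset: 8,
        ..Default::default()
    };
    assert_eq!(t.byte_offset, 8);
    assert_eq!(t._padding, [0u8; 4]);
    println!("{:?}", t);
}
```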
pub fn new(size: usize, align: Option<usize>) -> Result<Self, LayoutErr> {
    let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
    let layout = Layout::from_size_align(size, alignment)?;
    let ptr = unsafe { alloc::alloc(layout) };
    if ptr.is_null() {
        alloc::handle_alloc_error(layout);
    }
    Ok(Self { ptr, layout })
}
Should this return std::mem::MaybeUninit<Allocation>, or does that matter here since it is just bytes?
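For context on the `MaybeUninit` suggestion, a minimal sketch of its contract, independent of the PR's `Allocation` type: the wrapper lets you hold uninitialized memory legally, and reading it is only allowed after it has been written.

```rust
use std::mem::MaybeUninit;

fn main() {
    // Reserve space without initializing it; reading it now would be UB.
    let mut slot: MaybeUninit<u32> = MaybeUninit::uninit();

    // `write` initializes the slot (and returns a mutable reference).
    slot.write(42);

    // SAFETY: the slot was initialized by `write` above.
    let value = unsafe { slot.assume_init() };
    assert_eq!(value, 42);
    println!("{}", value);
}
```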
#[derive(PartialEq, Eq)]
pub struct Allocation {
    layout: Layout,
I assume I will find out why we need to track alignment?
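One likely answer to the question above: `std::alloc::dealloc` requires exactly the `Layout` (size and alignment) used at allocation time, so an owning type has to carry it to its `Drop` impl. A minimal sketch, independent of the PR's actual `Allocation`:

```rust
use std::alloc::{alloc, dealloc, Layout};

/// Minimal owned allocation that remembers its Layout so Drop can hand
/// `dealloc` the exact size/alignment used when allocating.
struct RawBuf {
    ptr: *mut u8,
    layout: Layout,
}

impl RawBuf {
    fn new(size: usize, align: usize) -> Self {
        let layout = Layout::from_size_align(size, align).unwrap();
        // SAFETY: layout has a non-zero size here.
        let ptr = unsafe { alloc(layout) };
        assert!(!ptr.is_null());
        RawBuf { ptr, layout }
    }
}

impl Drop for RawBuf {
    fn drop(&mut self) {
        // SAFETY: ptr was allocated with exactly this layout.
        unsafe { dealloc(self.ptr, self.layout) };
    }
}

fn main() {
    let buf = RawBuf::new(64, 32);
    // The returned pointer honors the requested alignment.
    assert_eq!(buf.ptr as usize % 32, 0);
    println!("aligned at {:p}", buf.ptr);
}
```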
/// let mut a_nd: ndarray::ArrayD<f32> = a.try_into().unwrap();
/// ```
#[derive(PartialEq)]
pub struct Tensor<'a> {
Does it make sense to have an owned Tensor and a TensorRef type? I guess that can get added later.
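One possible shape for the owned/borrowed split being suggested, purely illustrative and not the PR's design: an owning type that hands out cheap lifetime-bound views, much like `Vec<T>`/`&[T]`.

```rust
/// Owned tensor: owns its buffer and shape. (Illustrative only; the real
/// Tensor in this PR carries dtype, context, strides, etc.)
struct OwnedTensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

/// Borrowed view into some tensor's storage; no copy, tied to 'a.
struct TensorRef<'a> {
    data: &'a [f32],
    shape: &'a [usize],
}

impl OwnedTensor {
    /// Cheap, lifetime-bound view of the owned buffer.
    fn as_ref(&self) -> TensorRef<'_> {
        TensorRef { data: &self.data, shape: &self.shape }
    }
}

fn main() {
    let t = OwnedTensor { data: vec![1.0, 2.0, 3.0, 4.0], shape: vec![2, 2] };
    let view = t.as_ref();
    assert_eq!(view.data.len(), 4);
    assert_eq!(view.shape, &[2, 2]);
    println!("{} elements", view.data.len());
}
```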
pub(crate) data: Storage<'a>,
pub(crate) ctx: Context,
pub(crate) dtype: DataType,
pub(crate) shape: Vec<i64>,
// ^ not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
/// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
pub(crate) strides: Option<Vec<usize>>,
pub(crate) byte_offset: isize,
/// The number of elements in the `Tensor`.
pub(crate) size: usize,
I would make these pub(self) or remove pub entirely since it looks like you have unsafe code using them.
graph: &Graph,
lib: &'m M,
tensors: &[Tensor<'t>],
) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
Maybe make this execs a function of Tensor<'t>, or for<'a> Tensor<'a>?
pending: Arc<AtomicUsize>,
}

impl Job {
Any reason not to use rayon for this? I think you can tell it to spawn using tvm's thread pool: https://docs.rs/rayon/1.3.1/rayon/struct.ThreadPoolBuilder.html
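For reference, the `pending: Arc<AtomicUsize>` field above suggests a fork/join pattern where workers decrement a shared counter as they finish; a stdlib-only sketch of that pattern (rayon's scoped tasks would replace the manual spawning and counting):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    let pending = Arc::new(AtomicUsize::new(0));
    let mut handles = Vec::new();

    for _ in 0..4 {
        // Register the task before it starts.
        pending.fetch_add(1, Ordering::SeqCst);
        let pending = Arc::clone(&pending);
        handles.push(thread::spawn(move || {
            // ... a chunk of work would run here ...
            // Check out when done.
            pending.fetch_sub(1, Ordering::SeqCst);
        }));
    }

    for h in handles {
        h.join().unwrap();
    }
    // All workers have checked in, so the counter is back to zero.
    assert_eq!(pending.load(Ordering::SeqCst), 0);
    println!("done");
}
```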
free: Vec<usize>,
in_use: Vec<usize>,
Could be a good use for https://docs.rs/hibitset/0.6.3/hibitset/
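To make the suggestion concrete: the `free`/`in_use` index vectors could be collapsed into bitsets, where membership is a single bit and "first free slot" is a trailing-zeros count. A toy stdlib sketch for up to 64 slots (hibitset layers a hierarchy over the same idea for larger sets):

```rust
/// Toy bitset over at most 64 workspace slots.
struct SlotSet(u64);

impl SlotSet {
    fn new() -> Self { SlotSet(0) }
    fn insert(&mut self, i: usize) { self.0 |= 1 << i; }
    fn remove(&mut self, i: usize) { self.0 &= !(1 << i); }
    fn contains(&self, i: usize) -> bool { self.0 & (1 << i) != 0 }
    /// Index of the lowest set bit, if any: the first free slot.
    fn first(&self) -> Option<usize> {
        if self.0 == 0 { None } else { Some(self.0.trailing_zeros() as usize) }
    }
}

fn main() {
    let mut free = SlotSet::new();
    free.insert(3);
    free.insert(5);
    assert_eq!(free.first(), Some(3)); // take the lowest free slot
    free.remove(3);
    assert!(!free.contains(3));
    assert_eq!(free.first(), Some(5));
    println!("ok");
}
```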
if ws_size < size {
    return cur_ws_idx;
}
It seems like you could end up with some really extreme over-allocation if you have a combination of large and small tensors.
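To spell out the concern: a first-fit scan parks a request in the first workspace that is large enough, so a tiny tensor can claim a huge buffer. A best-fit pass, sketched below with hypothetical names (not the PR's code), picks the smallest workspace that still fits:

```rust
/// Best-fit: among workspaces with capacity >= `size`, pick the smallest.
/// Returns an index into `ws_sizes`, or None if nothing fits.
fn best_fit(ws_sizes: &[usize], size: usize) -> Option<usize> {
    ws_sizes
        .iter()
        .enumerate()
        .filter(|(_, &cap)| cap >= size)   // keep only workspaces that fit
        .min_by_key(|(_, &cap)| cap)       // ...and take the tightest one
        .map(|(idx, _)| idx)
}

fn main() {
    let ws = [4096, 64, 1_000_000];
    // First-fit would grab index 0 (4096 bytes) for a 48-byte request;
    // best-fit picks index 1 (64 bytes) instead.
    assert_eq!(best_fit(&ws, 48), Some(1));
    assert_eq!(best_fit(&ws, 2_000_000), None);
    println!("{:?}", best_fit(&ws, 48));
}
```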
[dependencies]
crossbeam = "0.7.3"
failure = "0.1"
Not anyhow or thiserror?
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
Do you want this print statement here?
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
Another print.
Yeah, I will turn it back on; there are just so many changes to juggle that I was bound to make a mistake or two.
@binarybana and @robo-corg, many of the things you identified existed in the previous graph runtime. It would be my preference to land this port, which just makes it compile against the new bindings, and then bring other patches to improve the code in follow-up PRs. Thoughts?
Yes, that works for me, which is why I approved the PR despite my comments.
That works great!
…ache#5830)
* Port graph-runtime to new API
* --amend
* Fix file lint
* Remove old travis file
* Add @kazum's patch
* Update rust/tvm-sys/src/datatype.rs
Co-authored-by: Andrew <[email protected]>
This is the fourth stage of the Rust rewrite which ports the graph-rt to the new API. The final stage after this will be turning on CI for the new bindings, updating docs, and deprecating the old bindings. This depends on #5764.