Skip to content

Commit

Permalink
Add more type support for burn-jit (#2454)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Nov 4, 2024
1 parent 5597657 commit 42f39f1
Show file tree
Hide file tree
Showing 151 changed files with 2,314 additions and 1,504 deletions.
181 changes: 105 additions & 76 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ version = "0.16.0"
[workspace.dependencies]
atomic_float = "1"
bytemuck = "1.19.0"
candle-core = { version = "0.6.0" }
candle-core = { version = "0.7" }
clap = { version = "4.5.20", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
Expand All @@ -53,6 +53,7 @@ js-sys = "0.3.69"
libm = "0.2.9"
log = { default-features = false, version = "0.4.22" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.41.3", features = ["lazy"] }
pretty_assertions = "1.4.1"
Expand Down Expand Up @@ -117,7 +118,7 @@ bincode = { version = "2.0.0-rc.3", features = [
#
# The following packages disable the "std" feature for no_std compatibility
#
derive-new = { version = "0.6.0", default-features = false }
derive-new = { version = "0.7.0", default-features = false }

blas-src = { version = "0.10.0", default-features = false }
half = { version = "2.4.1", features = [
Expand Down Expand Up @@ -152,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0dff475fec254e884f6b82e305e7a52adebf1dd7" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9460a3244aa2b42e1d6c36bd25b65f814f81ecd0" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/runtime/mspc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ impl AutodiffClient for ChannelClient {
.unwrap()
}

fn backward<B: Backend, const D: usize>(&self, root: AutodiffTensor<B, D>) -> Gradients {
fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
let node_id = root.node.id;
let grads = Gradients::new::<B, D>(root.node, root.primitive);
let grads = Gradients::new::<B>(root.node, root.primitive);
let (callback, receiver) = std::sync::mpsc::channel();

self.sender
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-autodiff/src/tests/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();

let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);
grad_1.to_data().assert_approx_eq(&expected, 3);
grad_1.to_data().assert_approx_eq(&expected, 5);

let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);
grad_2.to_data().assert_approx_eq(&expected, 3);
grad_2.to_data().assert_approx_eq(&expected, 5);
}

#[test]
Expand All @@ -42,10 +42,10 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();

let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);
grad_1.to_data().assert_approx_eq(&expected, 3);
grad_1.to_data().assert_approx_eq(&expected, 5);

let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);
grad_2.to_data().assert_approx_eq(&expected, 3);
grad_2.to_data().assert_approx_eq(&expected, 5);

let contains_nan = grad_2.contains_nan();
assert_eq!(contains_nan.into_scalar(), false);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/tests/adaptive_avgpool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ mod tests {

x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/tests/adaptive_avgpool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ mod tests {

x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/tests/avgpool1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ mod tests {

x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/tests/avgpool2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ mod tests {

x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
.assert_approx_eq(&x_grad_actual.into_data(), 4);
}
}
}
8 changes: 4 additions & 4 deletions crates/burn-autodiff/src/tests/cat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ mod tests {
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq(&grad_1_slice_1.to_data(), 3);
.assert_approx_eq(&grad_1_slice_1.to_data(), 5);
grad_1
.slice([1..2])
.to_data()
.assert_approx_eq(&grad_1_slice_2.to_data(), 3);
.assert_approx_eq(&grad_1_slice_2.to_data(), 5);

grad_2
.clone()
.slice([0..1])
.to_data()
.assert_approx_eq(&grad_2_slice_1.to_data(), 3);
.assert_approx_eq(&grad_2_slice_1.to_data(), 5);
grad_2
.slice([1..2])
.to_data()
.assert_approx_eq(&grad_2_slice_2.to_data(), 3);
.assert_approx_eq(&grad_2_slice_2.to_data(), 5);
}

#[test]
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/tests/conv1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
}
}
}
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/tests/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,15 +883,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/tests/conv3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,15 +512,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/tests/conv_transpose1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/tests/conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,15 +694,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/tests/conv_transpose3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,15 +567,15 @@ mod tests {
expected_grads
.bias
.to_data()
.assert_approx_eq(&bias_grad_actual.to_data(), 3);
.assert_approx_eq(&bias_grad_actual.to_data(), 5);
expected_grads
.x
.to_data()
.assert_approx_eq(&x_grad_actual.to_data(), 3);
.assert_approx_eq(&x_grad_actual.to_data(), 5);
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq(&weight_grad_actual.to_data(), 5);
}
}
}
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/tests/deform_conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1408,7 +1408,7 @@ mod tests {
expected_grads
.weight
.to_data()
.assert_approx_eq(&weight_grad_actual.to_data(), 3);
.assert_approx_eq_diff(&weight_grad_actual.to_data(), 0.04);
}
}
}
49 changes: 46 additions & 3 deletions crates/burn-autodiff/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,54 @@ mod transpose;

#[macro_export]
macro_rules! testgen_all {
// Avoid using paste dependency with no parameters
() => {
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
mod autodiff {
pub use super::*;
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;

// Behavior
pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;

$crate::testgen_with_float_param!();
}
};
([$($float:ident),*]) => {
mod autodiff {
pub use super::*;
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;

pub type FloatType = <TestBackend as burn_tensor::backend::Backend>::FloatElem;
pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;
pub type BoolType = <TestBackend as burn_tensor::backend::Backend>::BoolTensorPrimitive;

::paste::paste! {
$(mod [<$float _ty>] {
pub use super::*;

pub type TestBackend = TestBackend2<$float, IntType>;
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
pub type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;
pub type TestTensor<const D: usize> = TestTensor2<$float, IntType, D>;
pub type TestTensorInt<const D: usize> = TestTensorInt2<$float, IntType, D>;
pub type TestTensorBool<const D: usize> = TestTensorBool2<$float, IntType, D>;

type FloatType = $float;

$crate::testgen_with_float_param!();
})*
}
}
};
}

#[macro_export]
macro_rules! testgen_with_float_param {
() => {
// Behaviour
burn_autodiff::testgen_ad_broadcast!();
burn_autodiff::testgen_gradients!();
burn_autodiff::testgen_bridge!();
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/tests/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ mod tests {
let grad_2 = tensor_2.grad(&grads).unwrap();

let expected = TensorData::from([[82.1126, 99.0832], [82.1126, 99.0832]]);
grad_1.to_data().assert_approx_eq(&expected, 3);
grad_1.to_data().assert_approx_eq_diff(&expected, 0.02);

let expected = TensorData::from([[30.3093, 33.1204], [34.5819, 38.7694]]);
grad_2.to_data().assert_approx_eq(&expected, 3);
grad_2.to_data().assert_approx_eq(&expected, 2);
}
}
12 changes: 6 additions & 6 deletions crates/burn-autodiff/src/tests/transpose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ mod tests {

grad_1
.to_data()
.assert_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), false);
.assert_approx_eq(&TensorData::from([[6.0, 10.0], [6.0, 10.0]]), 3);
grad_2
.to_data()
.assert_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), false);
.assert_approx_eq(&TensorData::from([[3.0, 10.0], [3.0, 10.0]]), 3);
}

#[test]
Expand All @@ -48,13 +48,13 @@ mod tests {
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1.to_data().assert_eq(
grad_1.to_data().assert_approx_eq(
&TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),
false,
3,
);
grad_2.to_data().assert_eq(
grad_2.to_data().assert_approx_eq(
&TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),
false,
3,
);
}
}
2 changes: 2 additions & 0 deletions crates/burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ mod tests {
type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;
type TestAutodiffTensor<const D: usize> = burn_tensor::Tensor<TestAutodiffBackend, D>;

pub type FloatType = f32;

// test activation
burn_tensor::testgen_gelu!();
burn_tensor::testgen_prelu!();
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-core/src/nn/transformer/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ mod tests {
// Should produce the same tokens.
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
.assert_approx_eq(&output_2.into_data(), 2);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-core/src/nn/transformer/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ mod tests {

output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
.assert_approx_eq(&output_2.into_data(), 2);
}

#[test]
Expand Down
Loading

0 comments on commit 42f39f1

Please sign in to comment.