Skip to content

Commit

Permalink
Add QTensorOps docs + refactor tests to simplify inputs (#2557)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Nov 27, 2024
1 parent ba7ed4f commit b4cef57
Show file tree
Hide file tree
Showing 51 changed files with 787 additions and 2,211 deletions.
35 changes: 33 additions & 2 deletions contributor-book/src/guides/adding-a-new-operation-to-burn.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ supertrait[^supertrait]. This is the trait that is then implemented by the diffe
backends (such as `burn-ndarray` and `burn-wgpu`) which must implement the functions if no default
is provided.

In this case, we don't need to worry about `Bool` Tensors. Ops for `Float` is implemented under
In this case, we don't need to worry about `Bool` Tensors. `Float` ops are implemented under
[`crates/burn-tensor/src/tensor/ops/tensor.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/ops/tensor.rs#L991),
and for `Int` under
and `Int` ops under
[`crates/burn-tensor/src/tensor/ops/int_tensor.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/ops/int_tensor.rs#L539).
The current convention is ops of each type, if not unique to that type, are prefixed with the type.
So `powf` and sundry would be defined as `int_powf` for `IntTensorOps` and `float_powf` for
Expand All @@ -50,6 +50,17 @@ The `Int` Tensor function uses the ones defined for Float with 2 extra casts (LH
tensor, Output to an `Int`). Given that the rest of the code will only look at the float
implementations.

With the addition of quantized float tensors, the `Float` tensor primitive is represented by the
[`TensorPrimitive`](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/api/kind.rs#L69)
enum. This allows us to handle both float and quantized float operations in the `Tensor`
implementation, correctly dispatching to the corresponding op (float or quantized) based on the
variant. Following the same convention, the equivalent
[quantized tensor ops](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/ops/qtensor.rs#L45)
are prefixed with `q_*` (e.g., `q_reshape` instead of `float_reshape`). Most ops have a default
implementation that simply dequantizes the input into its floating-point representation, performs
the operation on the float tensor, and quantizes the output. Backends can overwrite specific
implementations when required/desired.

### Adding Tests

Additional Tests should be added to `burn-tensor` under
Expand All @@ -60,6 +71,26 @@ inserting the module name into `crates/burn-tensor/src/tests/ops/mod.rs`. Then a
necessary to define tests in the backends directly, save for those that require specific testing
such as `burn-autodiff`.

For float tensor operations, the
[`QTensorOps`](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/ops/qtensor.rs#L45)
counterpart is usually added at the same time with a default implementation (as mentioned in the
previous section). Tests for `q_*` ops follow a similar procedure: the test is added under
[`crates/burn-tensor/src/tests/quantization/ops/{op_name}.rs`](https://github.com/tracel-ai/burn/tree/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tests/quantization/ops),
the module name is inserted into `crates/burn-tensor/src/tests/quantization/ops/mod.rs` and finally
the test is added to the
[`testgen_quantization` macro](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tests/mod.rs#L67).
If you take a look at any of the existing tests for an operation on a quantized tensor,
you will see that the inputs and expected outputs are always defined with floating point values.
While it assumes that the quantization and dequantization are correct, it makes the tests much more
readable and easier to understand w.r.t. what is being tested. Effectively, the tests are there to
ensure that a tensor operation is invariant to quantization (up to some quantization error, of
course).

_Note: the tests try to use tensors with floating point values which can be de/quantized without
introducing too much quantization error, but the result always depends on the operation (e.g.,
tensor product of values can grow larger and significantly increase the output tensor range, leading
to more de/quantization error on the results)._

## Adding the Op to burn-autodiff

Since this is probably the hardest and the least straightforward, we'll cover this backend
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-tensor/src/tensor/quantization/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ impl QuantizationScheme {
let zero = Tensor::zeros_like(&range.max);
let max = range.max.max_pair(zero);

// If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
// scale to 0.1 to avoid division by zero.
let scale = max.sub(min.clone()).div_scalar(b - a);
let scale = scale.clone().mask_fill(scale.equal_elem(0.), 0.1);
let offset = Some(-(min.div(scale.clone()).sub_scalar(a)).int());
QuantizationParameters { scale, offset }
}
Expand Down
39 changes: 22 additions & 17 deletions crates/burn-tensor/src/tensor/quantization/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,20 @@ pub struct AffineQuantization<E: Float, Q: PrimInt, A: PrimInt> {
_a: PhantomData<A>,
}

fn valid_scale<E: Float>(mut scale: E) -> E {
// If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
// scale to 0.1 to avoid division by zero.
if scale.eq(&E::zero()) {
scale = E::from(0.1).unwrap();
}
scale
}

impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
/// Initialize an affine quantization scheme with the given parameters.
pub fn init(scale: E, offset: Q) -> Self {
let mut scale = scale;
// If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
// scale to 0.1 to avoid division by zero.
if scale.eq(&E::zero()) {
scale = E::from(0.1).unwrap();
}
Self {
scale,
scale: valid_scale(scale),
offset,
_a: PhantomData,
}
Expand All @@ -87,9 +90,13 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
let beta = E::max(beta, E::zero());

// Compute scale and offset to convert a floating point value in range `[alpha, beta]` to the quantized range
let scale = (beta - alpha) / (b - a);
let scale = valid_scale((beta - alpha) / (b - a));
let z = -(alpha / scale - a);
Self::init(scale, Q::from(z).unwrap())
Self {
scale,
offset: Q::from(z).unwrap(),
_a: PhantomData,
}
}

fn quantize(&self, values: &[E]) -> Vec<Q> {
Expand Down Expand Up @@ -136,14 +143,8 @@ pub struct SymmetricQuantization<E: Float, Q: PrimInt> {
impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
/// Initialize a symmetric quantization scheme with the given parameters.
pub fn init(scale: E) -> Self {
let mut scale = scale;
// If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
// scale to 0.1 to avoid division by zero.
if scale.eq(&E::zero()) {
scale = E::from(0.1).unwrap();
}
Self {
scale,
scale: valid_scale(scale),
_q: PhantomData,
}
}
Expand All @@ -162,7 +163,11 @@ impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {

// Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range
let alpha = alpha.abs().max(beta.abs());
Self::init((alpha + alpha) / (b - a))
let scale = valid_scale((alpha + alpha) / (b - a));
Self {
scale,
_q: PhantomData,
}
}

fn quantize(&self, values: &[E]) -> Vec<Q> {
Expand Down
39 changes: 38 additions & 1 deletion crates/burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,42 @@ macro_rules! testgen_all {
#[macro_export]
macro_rules! testgen_quantization {
() => {
// Quantized tensor utilities
pub mod qtensor {
use core::marker::PhantomData;

use burn_tensor::{
backend::Backend,
quantization::{QuantizationScheme, QuantizationType},
Tensor, TensorData,
};

pub struct QTensor<B: Backend, const D: usize> {
b: PhantomData<B>,
}

impl<B: Backend, const D: usize> QTensor<B, D> {
/// Creates a quantized int8 tensor from the floating point data using the default quantization scheme
/// (i.e., per-tensor symmetric quantization).
pub fn int8<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {
Self::int8_symmetric(floats)
}
/// Creates a quantized int8 tensor from the floating point data using per-tensor symmetric quantization.
pub fn int8_symmetric<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {
Tensor::from_floats(floats, &Default::default()).quantize_dynamic(
&QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
)
}
/// Creates a quantized int8 tensor from the floating point data using per-tensor affine quantization.
pub fn int8_affine<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {
Tensor::from_floats(floats, &Default::default()).quantize_dynamic(
&QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
)
}
}
}
pub use qtensor::*;

// test quantization
burn_tensor::testgen_calibration!();
burn_tensor::testgen_scheme!();
Expand Down Expand Up @@ -104,11 +140,12 @@ macro_rules! testgen_quantization {
burn_tensor::testgen_q_remainder!();
burn_tensor::testgen_q_repeat_dim!();
burn_tensor::testgen_q_reshape!();
burn_tensor::testgen_q_round!();
burn_tensor::testgen_q_select!();
burn_tensor::testgen_q_sin!();
burn_tensor::testgen_q_slice!();
burn_tensor::testgen_q_sort_argsort!();
// burn_tensor::testgen_q_split!();
burn_tensor::testgen_q_split!();
burn_tensor::testgen_q_sqrt!();
burn_tensor::testgen_q_stack!();
burn_tensor::testgen_q_sub!();
Expand Down
11 changes: 2 additions & 9 deletions crates/burn-tensor/src/tests/quantization/ops/abs.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
#[burn_tensor_testgen::testgen(q_abs)]
mod tests {
use super::*;
use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization};
use burn_tensor::{Tensor, TensorData};
use burn_tensor::TensorData;

#[test]
fn should_support_abs_ops() {
// Quantized [[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]
let data = TensorData::quantized(
vec![0i8, -25, 51, 76, 102, -127],
[2, 3],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
);
let tensor = TestTensor::<2>::from_data(data, &Default::default());
let tensor = QTensor::<TestBackend, 2>::int8([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]);

let output = tensor.abs();

Expand Down
91 changes: 12 additions & 79 deletions crates/burn-tensor/src/tests/quantization/ops/add.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
#[burn_tensor_testgen::testgen(q_add)]
mod tests {
use super::*;
use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization};
use burn_tensor::{Tensor, TensorData};
use burn_tensor::TensorData;

#[test]
fn test_add_d2() {
// Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let data = TensorData::quantized(
vec![0i8, 25, 51, 76, 102, 127],
[2, 3],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
);
let tensor_1 = TestTensor::<2>::from_data(data, &Default::default());
// Quantized [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]
let data = TensorData::quantized(
vec![69i8, 81, 92, 104, 115, 127],
[2, 3],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.08661418)),
);
let tensor_2 = TestTensor::<2>::from_data(data, &Default::default());
let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor_2 = QTensor::<TestBackend, 2>::int8([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);

let output = tensor_1 + tensor_2;

Expand All @@ -32,20 +19,8 @@ mod tests {

#[test]
fn test_add_broadcast() {
// Quantized [[0.0, 1.0, 2.0]]
let data = TensorData::quantized(
vec![0i8, 64, 127],
[1, 3],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.015748031)),
);
let tensor_1 = TestTensor::<2>::from_data(data, &Default::default());
// Quantized [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
let data = TensorData::quantized(
vec![48i8, 64, 79, 95, 111, 127],
[2, 3],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.062992126)),
);
let tensor_2 = TestTensor::<2>::from_data(data, &Default::default());
let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0]]);
let tensor_2 = QTensor::<TestBackend, 2>::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);

let output = tensor_1 + tensor_2;

Expand All @@ -58,22 +33,10 @@ mod tests {

#[test]
fn test_add_different_strides_rhs() {
// Quantized [[0.0, 1.0], [2.0, 3.0]]
let data = TensorData::quantized(
vec![0i8, 42, 85, 127],
[2, 2],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)),
);
// We need to execute an operation after `from data` to trigger inplace in some backends.
// Which is the operation that might be problematic in this case.
let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()) * 1;
// Quantized [[4.0, 5.0], [6.0, 7.0]]
let data = TensorData::quantized(
vec![73i8, 91, 109, 127],
[2, 2],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)),
);
let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()) * 1;
let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0], [2.0, 3.0]]) * 1;
let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0], [6.0, 7.0]]) * 1;

let output = tensor_1 + tensor_2.transpose();

Expand All @@ -86,22 +49,10 @@ mod tests {

#[test]
fn test_add_different_strides_lhs() {
// Quantized [[0.0, 1.0], [2.0, 3.0]]
let data = TensorData::quantized(
vec![0i8, 42, 85, 127],
[2, 2],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)),
);
// We need to execute an operation after `from data` to trigger inplace in some backends.
// Which is the operation that might be problematic in this case.
let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()) * 1;
// Quantized [[4.0, 5.0], [6.0, 7.0]]
let data = TensorData::quantized(
vec![73i8, 91, 109, 127],
[2, 2],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.05511811)),
);
let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()) * 1;
let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0], [2.0, 3.0]]) * 1;
let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0], [6.0, 7.0]]) * 1;

let output = tensor_1.transpose() + tensor_2;

Expand All @@ -114,22 +65,10 @@ mod tests {

#[test]
fn test_add_different_strides_broadcast() {
// Quantized [[0.0, 1.0], [2.0, 3.0]]
let data = TensorData::quantized(
vec![0i8, 42, 85, 127],
[2, 2],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.023622047)),
);
// We need to execute an operation after `from data` to trigger inplace in some backends.
// Which is the operation that might be problematic in this case.
let tensor_1 = TestTensor::<2>::from_data(data, &Default::default()) * 1;
// Quantized [[4.0, 5.0]]
let data = TensorData::quantized(
vec![102i8, 127],
[1, 2],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
);
let tensor_2 = TestTensor::<2>::from_data(data, &Default::default()) * 1;
let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0], [2.0, 3.0]]) * 1;
let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0]]) * 1;

let output = tensor_1.transpose() + tensor_2;

Expand All @@ -143,13 +82,7 @@ mod tests {
#[test]
fn should_support_add_scalar_ops() {
let scalar = 2.0;
// Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let data = TensorData::quantized(
vec![0i8, 25, 51, 76, 102, 127],
[2, 3],
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(0.03937008)),
);
let tensor = TestTensor::<2>::from_data(data, &Default::default());
let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

let output = tensor + scalar;

Expand Down
Loading

0 comments on commit b4cef57

Please sign in to comment.