diff --git a/packages/deep-learning/src/lib.cairo b/packages/deep-learning/src/lib.cairo
index 489ba4e62..f0cef23ad 100644
--- a/packages/deep-learning/src/lib.cairo
+++ b/packages/deep-learning/src/lib.cairo
@@ -1,9 +1,11 @@
 pub(crate) mod ops;
-pub(crate) mod utils;
 
 pub use ops::binary::{BinaryOpMetadata, tensor_add, tensor_mul, tensor_rem, tensor_lt};
 pub use ops::unary::{tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt, tensor_recip};
-pub use ops::reduce::{tensor_reduce_sum_1d, tensor_reduce_sum_nd, ReduceOpMetadata};
+pub use ops::reduce::{
+    tensor_sum_reduce_1d, tensor_sum_reduce_nd, tensor_max_reduce_1d, tensor_max_reduce_nd,
+    ReduceOpMetadata
+};
 
 #[derive(Drop, Copy)]
diff --git a/packages/deep-learning/src/ops/reduce.cairo b/packages/deep-learning/src/ops/reduce.cairo
index 3e8d9aa90..63e9a7fe3 100644
--- a/packages/deep-learning/src/ops/reduce.cairo
+++ b/packages/deep-learning/src/ops/reduce.cairo
@@ -3,6 +3,7 @@
 use core::fmt::Debug;
 use orion_dl::{Tensor, MutTensor};
 use orion_data_structures::vec::{NullableVec, VecTrait};
 use core::ops::AddAssign;
+use core::cmp::max;
 
 #[derive(Drop, Copy)]
@@ -11,7 +12,7 @@ pub(crate) struct ReduceOpMetadata {
     output_size: usize,
 }
 
-pub(crate) fn tensor_reduce_sum_1d<T, +AddAssign<T, T>, +Zero<T>, +Copy<T>, +Drop<T>, +Debug<T>>(
+pub(crate) fn tensor_sum_reduce_1d<T, +AddAssign<T, T>, +Zero<T>, +Copy<T>, +Drop<T>, +Debug<T>>(
     mut input: Tensor<T>
 ) -> Tensor<T> {
     let mut result = Zero::<T>::zero();
@@ -29,7 +30,7 @@ pub(crate) fn tensor_reduce_sum_1d<T, +AddAssign<T, T>, +Zero<T>, +Copy<T>, +Drop<T>, +De
     Tensor { data: result_data.span() }
 }
 
-pub(crate) fn tensor_reduce_sum_nd<T, +AddAssign<T, T>, +Copy<T>, +Drop<T>, +Debug<T>, +Zero<T>>(
+pub(crate) fn tensor_sum_reduce_nd<T, +AddAssign<T, T>, +Copy<T>, +Drop<T>, +Debug<T>, +Zero<T>>(
     mut input: Tensor<T>, ref metadata: ReduceOpMetadata
 ) -> MutTensor<T> {
     let mut result_data: NullableVec<T> = VecTrait::new(metadata.output_size);
@@ -54,23 +55,73 @@ pub(crate) fn tensor_reduce_sum_nd<T, +AddAssign<T, T>, +Copy<T>, +Drop<T>, +Debug<T>, +Z
     MutTensor { data: result_data }
 }
 
+
+pub(crate) fn tensor_max_reduce_1d<T, +Copy<T>, +Drop<T>, +Debug<T>, +PartialOrd<T>>(
+    mut input: Tensor<T>
+) -> Tensor<T> {
+    let mut result: Option<T> = Option::None(());
+
+    loop {
+        match input.data.pop_front() {
+            Option::Some(input_value) => {
+                result = match result {
+                    Option::Some(current_max) => Option::Some(max(*input_value, current_max)),
+                    Option::None(_) => Option::Some(*input_value),
+                };
+            },
+            Option::None(_) => { break; }
+        };
+    };
+
+    let mut result_data = ArrayTrait::new();
+    result_data.append(result.unwrap());
+
+    Tensor { data: result_data.span() }
+}
+
+pub(crate) fn tensor_max_reduce_nd<T, +Copy<T>, +Drop<T>, +Debug<T>, +PartialOrd<T>, +Zero<T>>(
+    mut input: Tensor<T>, ref metadata: ReduceOpMetadata
+) -> MutTensor<T> {
+    let mut result_data: NullableVec<T> = VecTrait::new(metadata.output_size);
+
+    loop {
+        match input.data.pop_front() {
+            Option::Some(input_value) => {
+                match metadata.output_indices.pop_front() {
+                    Option::Some(output_index) => {
+                        let current_max = result_data.at(*output_index);
+                        result_data.set(*output_index, max(*input_value, current_max));
+                    },
+                    Option::None(_) => {
+                        break; // This should never happen if metadata is correct
+                    }
+                }
+            },
+            Option::None(_) => { break; }
+        };
+    };
+
+    MutTensor { data: result_data }
+}
+
+
 #[cfg(test)]
-mod tests {
+mod tests_sum_reduce {
     use super::{
-        Tensor, MutTensor, VecTrait, NullableVec, ReduceOpMetadata, tensor_reduce_sum_1d,
-        tensor_reduce_sum_nd
+        Tensor, MutTensor, VecTrait, NullableVec, ReduceOpMetadata, tensor_sum_reduce_1d,
+        tensor_sum_reduce_nd
     };
 
     #[test]
     #[available_gas(20000000)]
-    fn test_tensor_reduce_sum_1d() {
+    fn test_tensor_sum_reduce_1d() {
         // Test case: Reduce sum along axis 0 for a 1D tensor (full reduction)
         let input_data: Array<u32> = array![1, 2, 3, 4, 5];
         let input = Tensor { data: input_data.span() };
 
-        let result = tensor_reduce_sum_1d(input);
+        let result = tensor_sum_reduce_1d(input);
 
         let expected = array![15]; // [1+2+3+4+5]
         assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
@@ -79,7 +130,7 @@ mod tests {
 
     #[test]
     #[available_gas(20000000)]
-    fn test_tensor_reduce_sum_2d() {
+    fn test_tensor_sum_reduce_2d() {
         // Test case: Reduce sum along axis 1 for a 2x3 tensor
         let input_data: Array<u32> = array![1, 2, 3, 4, 5, 6];
         let output_indices: Array<usize> = array![0, 0, 0, 1, 1, 1];
@@ -89,7 +140,7 @@ mod tests {
             output_indices: output_indices.span(), output_size: 2,
         };
 
-        let mut result = tensor_reduce_sum_nd(input, ref metadata);
+        let mut result = tensor_sum_reduce_nd(input, ref metadata);
 
         let expected = array![6, 15]; // [1+2+3, 4+5+6]
         assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
@@ -99,7 +150,7 @@ mod tests {
 
     #[test]
     #[available_gas(20000000)]
-    fn test_tensor_reduce_sum_3d_axis0() {
+    fn test_tensor_sum_reduce_3d_axis0() {
         // Test case: Reduce sum along axis 0 for a 2x2x2 tensor
         let input_data: Array<u32> = array![1, 2, 3, 4, 5, 6, 7, 8];
         let output_indices: Array<usize> = array![0, 1, 2, 3, 0, 1, 2, 3];
@@ -109,7 +160,7 @@ mod tests {
             output_indices: output_indices.span(), output_size: 4,
         };
 
-        let mut result = tensor_reduce_sum_nd(input, ref metadata);
+        let mut result = tensor_sum_reduce_nd(input, ref metadata);
 
         let expected = array![6, 8, 10, 12]; // [1+5, 2+6, 3+7, 4+8]
         assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
@@ -125,7 +176,7 @@ mod tests {
 
     #[test]
     #[available_gas(20000000)]
-    fn test_tensor_reduce_sum_3d_axis1() {
+    fn test_tensor_sum_reduce_3d_axis1() {
         // Test case: Reduce sum along axis 1 for a 2x3x2 tensor
         let input_data: Array<u32> = array![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
         let output_indices: Array<usize> = array![0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3];
@@ -135,7 +186,7 @@ mod tests {
            output_indices: output_indices.span(), output_size: 4,
         };
 
-        let mut result = tensor_reduce_sum_nd(input, ref metadata);
+        let mut result = tensor_sum_reduce_nd(input, ref metadata);
 
         let expected = array![9, 12, 27, 30]; // [1+3+5, 2+4+6, 7+9+11, 8+10+12]
         assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
@@ -151,7 +202,7 @@ mod tests {
 
     #[test]
     #[available_gas(20000000)]
-    fn test_tensor_reduce_sum_4d() {
+    fn test_tensor_sum_reduce_4d() {
         // Test case: Reduce sum along axis 2 for a 2x2x3x2 tensor
         let input_data: Array<u32> = array![
             1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24
@@ -165,7 +216,7 @@ mod tests {
             output_indices: output_indices.span(), output_size: 8,
         };
 
-        let mut result = tensor_reduce_sum_nd(input, ref metadata);
+        let mut result = tensor_sum_reduce_nd(input, ref metadata);
 
         let expected = array![9, 12, 27, 30, 45, 48, 63, 66];
         // [1+3+5, 2+4+6, 7+9+11, 8+10+12, 13+15+17, 14+16+18, 19+21+23, 20+22+24]
@@ -180,3 +231,122 @@ mod tests {
         };
     }
 }
+
+#[cfg(test)]
+mod tests_max_reduce {
+    use super::{
+        Tensor, MutTensor, VecTrait, NullableVec, ReduceOpMetadata, tensor_max_reduce_1d,
+        tensor_max_reduce_nd
+    };
+
+    #[test]
+    #[available_gas(20000000)]
+    fn test_tensor_max_reduce_1d() {
+        let input_data: Array<u32> = array![1, 5, 3, 4, 2];
+        let input = Tensor { data: input_data.span() };
+
+        let result = tensor_max_reduce_1d(input);
+
+        let expected = array![5];
+        assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
+        assert_eq!(*result.data.at(0), *expected[0], "Incorrect max");
+    }
+
+    #[test]
+    #[available_gas(20000000)]
+    fn test_tensor_max_reduce_2d() {
+        let input_data: Array<u32> = array![1, 2, 3, 4, 5, 6];
+        let output_indices: Array<usize> = array![0, 0, 0, 1, 1, 1];
+
+        let input = Tensor { data: input_data.span() };
+        let mut metadata = ReduceOpMetadata {
+            output_indices: output_indices.span(), output_size: 2,
+        };
+
+        let mut result = tensor_max_reduce_nd(input, ref metadata);
+
+        let expected = array![3, 6];
+        assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
+        assert_eq!(result.data.at(0), *expected[0], "Incorrect first max");
+        assert_eq!(result.data.at(1), *expected[1], "Incorrect second max");
+    }
+
+    #[test]
+    #[available_gas(20000000)]
+    fn test_tensor_max_reduce_3d_axis0() {
+        let input_data: Array<u32> = array![1, 2, 3, 4, 5, 6, 7, 8];
+        let output_indices: Array<usize> = array![0, 1, 2, 3, 0, 1, 2, 3];
+
+        let input = Tensor { data: input_data.span() };
+        let mut metadata = ReduceOpMetadata {
+            output_indices: output_indices.span(), output_size: 4,
+        };
+
+        let mut result = tensor_max_reduce_nd(input, ref metadata);
+
+        let expected = array![5, 6, 7, 8];
+        assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
+        let mut i = 0;
+        loop {
+            if i == expected.len() {
+                break;
+            }
+            assert_eq!(result.data.at(i), *expected[i], "Incorrect max at index");
+            i += 1;
+        };
+    }
+
+    #[test]
+    #[available_gas(20000000)]
+    fn test_tensor_max_reduce_3d_axis1() {
+        let input_data: Array<u32> = array![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
+        let output_indices: Array<usize> = array![0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3];
+
+        let input = Tensor { data: input_data.span() };
+        let mut metadata = ReduceOpMetadata {
+            output_indices: output_indices.span(), output_size: 4,
+        };
+
+        let mut result = tensor_max_reduce_nd(input, ref metadata);
+
+        let expected = array![5, 6, 11, 12];
+        assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
+        let mut i = 0;
+        loop {
+            if i == expected.len() {
+                break;
+            }
+            assert_eq!(result.data.at(i), *expected[i], "Incorrect max at index");
+            i += 1;
+        };
+    }
+
+    #[test]
+    #[available_gas(20000000)]
+    fn test_tensor_max_reduce_4d() {
+        let input_data: Array<u32> = array![
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24
+        ];
+        let output_indices: Array<usize> = array![
+            0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7
+        ];
+
+        let input = Tensor { data: input_data.span() };
+        let mut metadata = ReduceOpMetadata {
+            output_indices: output_indices.span(), output_size: 8,
+        };
+
+        let mut result = tensor_max_reduce_nd(input, ref metadata);
+
+        let expected = array![5, 6, 11, 12, 17, 18, 23, 24];
+        assert_eq!(result.data.len(), expected.len(), "Incorrect result length");
+        let mut i = 0;
+        loop {
+            if i == expected.len() {
+                break;
+            }
+            assert_eq!(result.data.at(i), *expected[i], "Incorrect max at index");
+            i += 1;
+        };
+    }
+}
diff --git a/packages/deep-learning/src/utils.cairo b/packages/deep-learning/src/utils.cairo
deleted file mode 100644
index 40d4aafb9..000000000
--- a/packages/deep-learning/src/utils.cairo
+++ /dev/null
@@ -1,7 +0,0 @@
-pub(crate) fn u32_max(a: u32, b: u32) -> u32 {
-    if a > b {
-        a
-    } else {
-        b
-    }
-}