Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Kronecker product #261

Merged
merged 4 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/linalg/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
## [Dot product](./src/dot.cairo)

The dot product or scalar product is an algebraic operation that takes two equal-length sequences of numbers (usually coordinate vectors), and returns a single number. Algebraically, the dot product is the sum of the products of the corresponding entries of the two sequences of numbers ([see also](https://en.wikipedia.org/wiki/Dot_product)).

## [Kronecker product](./src/kron.cairo)

The Kronecker product is an an algebraic operation that takes two equal-length sequences of numbers and returns an array of numbers([see also](https://numpy.org/doc/stable/reference/generated/numpy.kron.html)).
42 changes: 42 additions & 0 deletions src/linalg/src/kron.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use core::array::SpanTrait;
//! Kronecker product of two arrays

#[derive(Drop, Copy, PartialEq)]
enum KronError {
UnequalLength,
}

/// Compute the Kronecker product for 2 given arrays.
/// # Arguments
/// * `xs` - The first sequence of len L.
/// * `ys` - The second sequence of len L.
/// # Returns
/// * `Result<Array<T>, KronError>` - The Kronecker product.
fn kron<T, +Mul<T>, +AddEq<T>, +Zeroable<T>, +Copy<T>, +Drop<T>,>(
mut xs: Span<T>, mut ys: Span<T>
) -> Result<Array<T>, KronError> {
// [Check] Inputs
if xs.len() != ys.len() {
return Result::Err(KronError::UnequalLength);
}
assert(xs.len() == ys.len(), 'Arrays must have the same len');
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assert is useless


// [Compute] Kronecker product in a loop
let mut array = array![];
loop {
match xs.pop_front() {
Option::Some(x_value) => {
let mut ys_clone = ys;
loop {
match ys_clone.pop_front() {
Option::Some(y_value) => { array.append(*x_value * *y_value); },
Option::None => { break; },
};
};
},
Option::None => { break; },
};
};

Result::Ok(array)
}
1 change: 1 addition & 0 deletions src/linalg/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod dot;
mod kron;

#[cfg(test)]
mod tests;
1 change: 1 addition & 0 deletions src/linalg/src/tests.cairo
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod dot_test;
mod kron_test;
29 changes: 29 additions & 0 deletions src/linalg/src/tests/kron_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use alexandria_linalg::kron::{kron, KronError};

#[test]
#[available_gas(2000000)]
fn kron_product_test() {
let mut xs: Array<u64> = array![1, 10, 100];
let mut ys = array![5, 6, 7];
let zs = kron(xs.span(), ys.span()).unwrap();
assert(*zs[0] == 5, 'wrong value at index 0');
assert(*zs[1] == 6, 'wrong value at index 1');
assert(*zs[2] == 7, 'wrong value at index 2');
assert(*zs[3] == 50, 'wrong value at index 3');
assert(*zs[4] == 60, 'wrong value at index 4');
assert(*zs[5] == 70, 'wrong value at index 5');
assert(*zs[6] == 500, 'wrong value at index 6');
assert(*zs[7] == 600, 'wrong value at index 7');
assert(*zs[8] == 700, 'wrong value at index 8');
}

#[test]
#[available_gas(2000000)]
fn kron_product_test_check_len() {
let mut xs: Array<u64> = array![1];
let mut ys = array![];
assert(
kron(xs.span(), ys.span()) == Result::Err(KronError::UnequalLength),
'Arrays must have the same len'
);
}