Skip to content

Commit

Permalink
Add basic storage array.
Browse files Browse the repository at this point in the history
commit-id:2da700e2
  • Loading branch information
gilbens-starkware committed Jul 9, 2024
1 parent 9e72a1c commit f7e928e
Show file tree
Hide file tree
Showing 6 changed files with 491 additions and 26 deletions.
3 changes: 3 additions & 0 deletions corelib/src/starknet/storage.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use starknet::storage_access::StorageBaseAddress;
use starknet::SyscallResult;
use starknet::storage_access::storage_base_address_from_felt252;

mod vec;
pub use vec::{StorageVec, StorageVecTrait, MutableStorageVecTrait};


/// A pointer to an address in storage, can be used to read and write values, if the generic type
/// supports it (e.g. basic types like `felt252`).
Expand Down
115 changes: 115 additions & 0 deletions corelib/src/starknet/storage/vec.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use super::{
StorageAsPath, StorageAsPointer, StoragePath, StoragePointer0Offset, Mutable, StoragePathTrait,
StoragePathUpdateTrait, StoragePointerReadAccess, StoragePointerWriteAccess
};

/// A type to represent a vec in storage. The length of the storage is stored in the storage
/// base, while the elements are stored in hash(storage_base, index).
#[phantom]
pub struct StorageVec<T> {}

impl StorageVecDrop<T> of Drop<StorageVec<T>> {}
impl StorageVecCopy<T> of Copy<StorageVec<T>> {}

/// Implement as_ptr for StorageVec.
impl StorageVecAsPointer<T> of StorageAsPointer<StoragePath<StorageVec<T>>> {
type Value = u64;
fn as_ptr(self: @StoragePath<StorageVec<T>>) -> StoragePointer0Offset<u64> {
StoragePointer0Offset { address: (*self).finalize() }
}
}

/// Implement as_ptr for Mutable<StorageVec>.
impl MutableStorageVecAsPointer<T> of StorageAsPointer<StoragePath<Mutable<StorageVec<T>>>> {
type Value = Mutable<u64>;
fn as_ptr(self: @StoragePath<Mutable<StorageVec<T>>>) -> StoragePointer0Offset<Mutable<u64>> {
StoragePointer0Offset { address: (*self).finalize() }
}
}


/// Trait for the interface of a storage vec.
pub trait StorageVecTrait<T> {
type ElementType;
fn at(self: T, index: u64) -> StoragePath<Self::ElementType>;
fn len(self: T) -> u64;
}

/// Implement `StorageVecTrait` for `StoragePath<StorageVec<T>>`.
impl StorageVecImpl<T> of StorageVecTrait<StoragePath<StorageVec<T>>> {
type ElementType = T;
fn at(self: StoragePath<StorageVec<T>>, index: u64) -> StoragePath<T> {
let vec_len = self.len();
assert!(index < vec_len, "Index out of bounds");
self.update(index)
}
fn len(self: StoragePath<StorageVec<T>>) -> u64 {
self.as_ptr().read()
}
}

/// Implement `StorageVecTrait` for any type that implements StorageAsPath into a storage path
/// that implements StorageVecTrait.
impl PathableStorageVecImpl<
T,
+Drop<T>,
impl PathImpl: StorageAsPath<T>,
impl VecTraitImpl: StorageVecTrait<StoragePath<PathImpl::Value>>
> of StorageVecTrait<T> {
type ElementType = VecTraitImpl::ElementType;
fn at(self: T, index: u64) -> StoragePath<VecTraitImpl::ElementType> {
self.as_path().at(index)
}
fn len(self: T) -> u64 {
self.as_path().len()
}
}

/// Trait for the interface of a mutable storage vec.
pub trait MutableStorageVecTrait<T> {
type ElementType;
fn at(self: T, index: u64) -> StoragePath<Mutable<Self::ElementType>>;
fn len(self: T) -> u64;
fn append(self: T) -> StoragePath<Mutable<Self::ElementType>>;
}

/// Implement `MutableStorageVecTrait` for `StoragePath<Mutable<StorageVec<T>>`.
impl MutableStorageVecImpl<
T, +Drop<T>
> of MutableStorageVecTrait<StoragePath<Mutable<StorageVec<T>>>> {
type ElementType = T;
fn at(self: StoragePath<Mutable<StorageVec<T>>>, index: u64) -> StoragePath<Mutable<T>> {
let vec_len = self.len();
assert!(index < vec_len, "Index out of bounds");
self.update(index)
}
fn len(self: StoragePath<Mutable<StorageVec<T>>>) -> u64 {
self.as_ptr().read()
}
fn append(self: StoragePath<Mutable<StorageVec<T>>>) -> StoragePath<Mutable<T>> {
let vec_len = self.len();
self.as_ptr().write(vec_len + 1);
self.update(vec_len)
}
}

/// Implement `MutableStorageVecTrait` for any type that implements StorageAsPath into a storage
/// path that implements MutableStorageVecTrait.
impl PathableMutableStorageVecImpl<
T,
+Drop<T>,
impl PathImpl: StorageAsPath<T>,
impl VecTraitImpl: MutableStorageVecTrait<StoragePath<PathImpl::Value>>
> of MutableStorageVecTrait<T> {
type ElementType = VecTraitImpl::ElementType;
fn at(self: T, index: u64) -> StoragePath<Mutable<VecTraitImpl::ElementType>> {
self.as_path().at(index)
}
fn len(self: T) -> u64 {
self.as_path().len()
}
fn append(self: T) -> StoragePath<Mutable<VecTraitImpl::ElementType>> {
self.as_path().append()
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use core::starknet::storage::StoragePointerReadAccess;
use core::starknet::storage::MutableStorageVecTrait;
use core::starknet::storage::StoragePathEntry;


#[starknet::contract]
mod contract_with_map {
use starknet::storage::Map;
#[storage]
struct Storage {
simple: Map::<u64, u32>,
nested: Map::<u64, Map<u64, u32>>,
}
}

#[starknet::contract]
mod contract_with_vec {
use starknet::storage::StorageVec;
#[storage]
struct Storage {
simple: StorageVec::<u32>,
nested: StorageVec::<StorageVec<u32>>,
}
}

#[test]
fn test_simple_member_write_to_map() {
let mut map_contract_state = contract_with_map::contract_state_for_testing();
let mut vec_contract_state = contract_with_vec::contract_state_for_testing();
let vec_entry = vec_contract_state.simple.append();
map_contract_state.simple.entry(0).write(1);
assert_eq!(vec_entry.read(), 1);
}

#[test]
fn test_simple_member_write_to_vec() {
let mut map_contract_state = contract_with_map::contract_state_for_testing();
let mut vec_contract_state = contract_with_vec::contract_state_for_testing();
vec_contract_state.simple.append().write(1);
assert_eq!(map_contract_state.simple.entry(0).read(), 1);
}

#[test]
fn test_nested_member_write_to_map() {
let mut map_contract_state = contract_with_map::contract_state_for_testing();
let mut vec_contract_state = contract_with_vec::contract_state_for_testing();
let vec_entry = vec_contract_state.nested.append().append();
map_contract_state.nested.entry(0).entry(0).write(1);
assert_eq!(vec_entry.read(), 1);
}

#[test]
fn test_nested_member_write_to_vec() {
let mut map_contract_state = contract_with_map::contract_state_for_testing();
let mut vec_contract_state = contract_with_vec::contract_state_for_testing();
vec_contract_state.nested.append().append().write(1);
assert_eq!(map_contract_state.nested.entry(0).entry(0).read(), 1);
}
2 changes: 2 additions & 0 deletions crates/cairo-lang-starknet/cairo_level_tests/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ mod storage_access;
#[cfg(test)]
mod contract_address_test;
mod utils;
#[cfg(test)]
mod collections_test;
101 changes: 100 additions & 1 deletion crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::utils::{deserialized, serialized};
use core::integer::BoundedInt;
use core::num::traits::Zero;
use core::byte_array::ByteArrayTrait;
use starknet::storage::StorageVec;

impl StorageAddressPartialEq of PartialEq<StorageAddress> {
fn eq(lhs: @StorageAddress, rhs: @StorageAddress) -> bool {
Expand Down Expand Up @@ -78,15 +79,24 @@ struct NonZeros {
value_felt252: NonZero<felt252>,
}

#[starknet::storage_node]
struct StorageVecs {
vec: StorageVec<u32>,
vec_of_vecs: StorageVec<StorageVec<u32>>,
}

#[starknet::contract]
mod test_contract {
use super::{AbcEtc, ByteArrays, NonZeros};
use core::starknet::storage::StoragePointerWriteAccess;
use super::{AbcEtc, ByteArrays, NonZeros, StorageVecs,};
use starknet::storage::{StorageVecTrait, MutableStorageVecTrait, StorageAsPath,};

#[storage]
struct Storage {
data: AbcEtc,
byte_arrays: ByteArrays,
non_zeros: NonZeros,
vecs: StorageVecs,
}

#[external(v0)]
Expand Down Expand Up @@ -128,6 +138,46 @@ mod test_contract {
pub fn get_non_zeros(self: @ContractState) -> NonZeros {
self.non_zeros.read()
}

#[external(v0)]
pub fn append_to_vec(ref self: ContractState, value: u32) {
self.vecs.vec.append().write(value);
}

#[external(v0)]
pub fn get_vec_length(self: @ContractState) -> u64 {
self.vecs.vec.len()
}

#[external(v0)]
pub fn get_vec_element(self: @ContractState, index: u64) -> u32 {
self.vecs.vec.at(index).read()
}

#[external(v0)]
pub fn append_an_vec(ref self: ContractState) {
self.vecs.vec_of_vecs.append();
}

#[external(v0)]
pub fn append_to_nested_vec(ref self: ContractState, index: u64, value: u32) {
self.vecs.vec_of_vecs.at(index).append().write(value);
}

#[external(v0)]
pub fn get_vec_of_vecs_length(self: @ContractState) -> u64 {
self.vecs.vec_of_vecs.len()
}

#[external(v0)]
pub fn get_nested_vec_length(self: @ContractState, index: u64) -> u64 {
self.vecs.vec_of_vecs.at(index).len()
}

#[external(v0)]
pub fn get_nested_vec_element(self: @ContractState, index: u64, nested_index: u64) -> u32 {
self.vecs.vec_of_vecs.at(index).at(nested_index).read()
}
}

#[test]
Expand Down Expand Up @@ -224,3 +274,52 @@ fn test_read_write_non_zero() {
assert!(test_contract::__external::set_non_zeros(serialized(x.clone())).is_empty());
assert_eq!(deserialized(test_contract::__external::get_non_zeros(serialized(()))), x);
}

#[test]
fn test_storage_array() {
assert!(test_contract::__external::append_to_vec(serialized(1_u32)).is_empty());
assert!(test_contract::__external::append_to_vec(serialized(2_u32)).is_empty());
assert!(test_contract::__external::append_to_vec(serialized(3_u32)).is_empty());
assert_eq!(deserialized(test_contract::__external::get_vec_length(serialized(()))), 3);
assert_eq!(deserialized(test_contract::__external::get_vec_element(serialized(0_u64))), 1);
assert_eq!(deserialized(test_contract::__external::get_vec_element(serialized(1_u64))), 2);
assert_eq!(deserialized(test_contract::__external::get_vec_element(serialized(2_u64))), 3);
}

#[test]
fn test_storage_vec_of_vecs() {
assert!(test_contract::__external::append_an_vec(serialized(())).is_empty());
assert!(test_contract::__external::append_to_nested_vec(serialized((0_u64, 1_u32))).is_empty());
assert!(test_contract::__external::append_to_nested_vec(serialized((0_u64, 2_u32))).is_empty());
assert!(test_contract::__external::append_to_nested_vec(serialized((0_u64, 3_u32))).is_empty());
assert!(test_contract::__external::append_an_vec(serialized(())).is_empty());
assert!(test_contract::__external::append_to_nested_vec(serialized((1_u64, 4_u32))).is_empty());
assert!(test_contract::__external::append_to_nested_vec(serialized((1_u64, 5_u32))).is_empty());
assert_eq!(deserialized(test_contract::__external::get_vec_of_vecs_length(serialized(()))), 2);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_length(serialized(0_u64))), 3
);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_element(serialized((0_u64, 0_u64)))),
1
);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_element(serialized((0_u64, 1_u64)))),
2
);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_element(serialized((0_u64, 2_u64)))),
3
);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_length(serialized(1_u64))), 2
);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_element(serialized((1_u64, 0_u64)))),
4
);
assert_eq!(
deserialized(test_contract::__external::get_nested_vec_element(serialized((1_u64, 1_u64)))),
5
);
}
Loading

0 comments on commit f7e928e

Please sign in to comment.