Skip to content

Commit

Permalink
Updated bounded_int_constrain to allow constraining NonZero of ranges. (
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Jun 5, 2024
1 parent caa28cb commit c81bec3
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
14 changes: 11 additions & 3 deletions crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use num_bigint::{BigInt, ToBigInt};
use num_traits::{One, Signed};
use starknet_types_core::felt::Felt as Felt252;

use super::non_zero::nonzero_ty;
use super::non_zero::{nonzero_ty, NonZeroType};
use super::range_check::RangeCheckType;
use super::utils::{reinterpret_cast_signature, Range};
use crate::define_libfunc_hierarchy;
Expand Down Expand Up @@ -291,7 +291,14 @@ impl NamedLibfunc for BoundedIntConstrainLibfunc {
[_, _] => Err(SpecializationError::UnsupportedGenericArg),
_ => Err(SpecializationError::WrongNumberOfGenericArgs),
}?;
let range = Range::from_type(context, ty.clone())?;
let ty_info = context.get_type_info(ty.clone())?;
let is_nz = ty_info.long_id.generic_id == NonZeroType::ID;
let range = if is_nz {
let inner_ty = args_as_single_type(&ty_info.long_id.generic_args)?;
Range::from_type(context, inner_ty)?
} else {
Range::from_type_info(&ty_info)?
};
require(&range.lower < boundary && boundary < &range.upper)
.ok_or(SpecializationError::UnsupportedGenericArg)?;
let low_range = Range::half_open(range.lower, boundary.clone());
Expand All @@ -300,11 +307,12 @@ impl NamedLibfunc for BoundedIntConstrainLibfunc {
.ok_or(SpecializationError::UnsupportedGenericArg)?;
let range_check_type = context.get_concrete_type(RangeCheckType::id(), &[])?;
let branch_signature = |rng: Range| {
let ty = bounded_int_ty(context, rng.lower, rng.upper - 1)?;
Ok(BranchSignature {
vars: vec![
OutputVarInfo::new_builtin(range_check_type.clone(), 0),
OutputVarInfo {
ty: bounded_int_ty(context, rng.lower, rng.upper - 1)?,
ty: if is_nz { nonzero_ty(context, &ty)? } else { ty },
ref_info: OutputVarReferenceInfo::SameAsParam { param_idx: 1 },
},
],
Expand Down
68 changes: 68 additions & 0 deletions tests/e2e_test_data/libfuncs/bounded_int
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,74 @@ test::foo@0([0]: RangeCheck, [1]: BoundedInt<0, 68056473384187692692674921486353

//! > ==========================================================================

//! > bounded_int_constrain libfunc non-zero.

//! > test_runner_name
SmallE2ETestRunner

//! > cairo
extern type BoundedInt<const MIN: felt252, const MAX: felt252>;
type Res = Result<NonZero<BoundedInt<0, 0x7f>>, NonZero<BoundedInt<0x80, 0xff>>>;

extern fn bounded_int_constrain<T, const BOUNDARY: felt252>(value: T) -> Res implicits(RangeCheck) nopanic;

fn foo(value: NonZero<u8>) -> Res {
bounded_int_constrain::<_, 0x80>(value)
}

//! > casm
%{ memory[ap + 0] = 340282366920938463463374607431768211456 <= (memory[fp + -3] + -128) % PRIME %}
jmp rel 7 if [ap + 0] != 0, ap++;
[ap + 0] = [fp + -3] + -128, ap++;
[ap + -1] = [[fp + -4] + 0];
jmp rel 11;
[ap + 0] = [fp + -3] + 340282366920938463463374607431768211328, ap++;
[ap + -1] = [[fp + -4] + 0];
[ap + 0] = [fp + -4] + 1, ap++;
[ap + 0] = 0, ap++;
[ap + 0] = [fp + -3], ap++;
ret;
[ap + 0] = [fp + -4] + 1, ap++;
[ap + 0] = 1, ap++;
[ap + 0] = [fp + -3], ap++;
ret;

//! > function_costs
test::foo: OrderedHashMap({Const: 770})

//! > sierra_code
type RangeCheck = RangeCheck [storable: true, drop: false, dup: false, zero_sized: false];
type BoundedInt<0, 127> = BoundedInt<0, 127> [storable: true, drop: true, dup: true, zero_sized: false];
type NonZero<BoundedInt<0, 127>> = NonZero<BoundedInt<0, 127>> [storable: true, drop: true, dup: true, zero_sized: false];
type BoundedInt<128, 255> = BoundedInt<128, 255> [storable: true, drop: true, dup: true, zero_sized: false];
type NonZero<BoundedInt<128, 255>> = NonZero<BoundedInt<128, 255>> [storable: true, drop: true, dup: true, zero_sized: false];
type core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>> = Enum<ut@core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, NonZero<BoundedInt<0, 127>>, NonZero<BoundedInt<128, 255>>> [storable: true, drop: true, dup: true, zero_sized: false];
type u8 = u8 [storable: true, drop: true, dup: true, zero_sized: false];
type NonZero<u8> = NonZero<u8> [storable: true, drop: true, dup: true, zero_sized: false];

libfunc bounded_int_constrain<NonZero<u8>, 128> = bounded_int_constrain<NonZero<u8>, 128>;
libfunc branch_align = branch_align;
libfunc enum_init<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, 0> = enum_init<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, 0>;
libfunc store_temp<RangeCheck> = store_temp<RangeCheck>;
libfunc store_temp<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>> = store_temp<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>>;
libfunc enum_init<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, 1> = enum_init<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, 1>;

bounded_int_constrain<NonZero<u8>, 128>([0], [1]) { fallthrough([2], [3]) 6([4], [5]) }; // 0
branch_align() -> (); // 1
enum_init<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, 0>([3]) -> ([6]); // 2
store_temp<RangeCheck>([2]) -> ([2]); // 3
store_temp<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>>([6]) -> ([6]); // 4
return([2], [6]); // 5
branch_align() -> (); // 6
enum_init<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>, 1>([5]) -> ([7]); // 7
store_temp<RangeCheck>([4]) -> ([4]); // 8
store_temp<core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>>([7]) -> ([7]); // 9
return([4], [7]); // 10

test::foo@0([0]: RangeCheck, [1]: NonZero<u8>) -> (RangeCheck, core::result::Result::<core::zeroable::NonZero::<test::BoundedInt::<0, 127>>, core::zeroable::NonZero::<test::BoundedInt::<128, 255>>>);

//! > ==========================================================================

//! > bounded_int_is_zero libfunc for i8.

//! > test_runner_name
Expand Down

0 comments on commit c81bec3

Please sign in to comment.