From 606426a83e9d5d02b87a4adc859d2b1cec7974d2 Mon Sep 17 00:00:00 2001 From: Rigidity Date: Tue, 23 Jul 2024 02:04:22 -0400 Subject: [PATCH] Refactor attributes --- crates/rue-typing/src/check.rs | 2 + crates/rue-typing/src/check/attributes.rs | 140 +++++++++++++++++++++ crates/rue-typing/src/check/check_type.rs | 147 +++++----------------- 3 files changed, 176 insertions(+), 113 deletions(-) create mode 100644 crates/rue-typing/src/check/attributes.rs diff --git a/crates/rue-typing/src/check.rs b/crates/rue-typing/src/check.rs index 0073913..61f05ea 100644 --- a/crates/rue-typing/src/check.rs +++ b/crates/rue-typing/src/check.rs @@ -1,5 +1,6 @@ use std::fmt; +mod attributes; mod check_error; mod check_type; mod simplify_and; @@ -9,6 +10,7 @@ mod stringify_check; pub use check_error::*; +pub(crate) use attributes::*; pub(crate) use check_type::*; pub(crate) use simplify_and::*; pub(crate) use simplify_check::*; diff --git a/crates/rue-typing/src/check/attributes.rs b/crates/rue-typing/src/check/attributes.rs new file mode 100644 index 0000000..407ad3f --- /dev/null +++ b/crates/rue-typing/src/check/attributes.rs @@ -0,0 +1,140 @@ +use std::collections::{HashMap, HashSet, VecDeque}; + +use num_bigint::BigInt; +use num_traits::One; + +use crate::{Type, TypeId, TypeSystem}; + +use super::CheckError; + +pub(crate) struct Attributes { + pub atom_count: usize, + pub bytes32_count: usize, + pub public_key_count: usize, + pub pairs: Vec<(TypeId, TypeId)>, + pub values: HashMap, + pub length: usize, +} + +impl Attributes { + pub fn all_atoms(&self) -> bool { + self.atom_count == self.length + } + + pub fn all_bytes32(&self) -> bool { + self.bytes32_count == self.length + } + + pub fn all_public_key(&self) -> bool { + self.public_key_count == self.length + } + + pub fn all_pairs(&self) -> bool { + self.pairs.len() == self.length + } + + pub fn all_value(&self, value: &BigInt) -> bool { + self.values.get(value).copied().unwrap_or(0) == self.length + } + + pub fn atoms_are_bytes32(&self) -> bool { + self.bytes32_count == self.atom_count + } + + pub fn atoms_are_public_key(&self) -> bool { + self.public_key_count == self.atom_count + } + + pub fn atoms_are_value(&self, value: &BigInt) -> bool { + self.values.get(value).copied().unwrap_or(0) == self.atom_count + } +} + +pub(crate) fn union_attributes( + db: &TypeSystem, + items: &[TypeId], + is_lhs: bool, + other_type_id: TypeId, + visited: &mut HashSet<(TypeId, TypeId)>, +) -> Result { + let mut atom_count = 0; + let mut bytes32_count = 0; + let mut public_key_count = 0; + let mut pairs = Vec::new(); + let mut values = HashMap::new(); + + let mut items: VecDeque<_> = items.iter().copied().collect(); + let mut length = 0; + + while !items.is_empty() { + let item = items.remove(0).unwrap(); + length += 1; + + let key = if is_lhs { + (item, other_type_id) + } else { + (other_type_id, item) + }; + + if !visited.insert(key) { + return Err(CheckError::Recursive(key.0, key.1)); + } + + match db.get(item) { + Type::Ref(..) => unreachable!(), + Type::Lazy(..) => unreachable!(), + Type::Alias(..) => unreachable!(), + Type::Struct(..) => unreachable!(), + Type::Callable(..) => unreachable!(), + Type::Generic => return Err(CheckError::Impossible(key.0, key.1)), + Type::Unknown => {} + Type::Never => { + length -= 1; + } + Type::Atom | Type::Bytes | Type::Int => { + atom_count += 1; + } + Type::Bytes32 => { + atom_count += 1; + bytes32_count += 1; + } + Type::PublicKey => { + atom_count += 1; + public_key_count += 1; + } + Type::Nil => { + atom_count += 1; + *values.entry(BigInt::ZERO).or_insert(0) += 1; + } + Type::True => { + atom_count += 1; + *values.entry(BigInt::one()).or_insert(0) += 1; + } + Type::False => { + atom_count += 1; + *values.entry(BigInt::ZERO).or_insert(0) += 1; + } + Type::Value(value) => { + atom_count += 1; + *values.entry(value.clone()).or_insert(0) += 1; + } + Type::Pair(first, rest) => { + pairs.push((*first, *rest)); + } + Type::Union(child_items) => { + items.extend(child_items); + } + } + + visited.remove(&key); + } + + Ok(Attributes { + atom_count, + bytes32_count, + public_key_count, + pairs, + values, + length, + }) +} diff --git a/crates/rue-typing/src/check/check_type.rs b/crates/rue-typing/src/check/check_type.rs index 8fb5662..fa7b0d4 100644 --- a/crates/rue-typing/src/check/check_type.rs +++ b/crates/rue-typing/src/check/check_type.rs @@ -1,25 +1,19 @@ -use std::{ - collections::{HashMap, HashSet, VecDeque}, - hash::BuildHasher, -}; +use std::collections::HashSet; use num_bigint::BigInt; use num_traits::One; use crate::{bigint_to_bytes, Comparison, Type, TypeId, TypeSystem}; -use super::{Check, CheckError}; +use super::{union_attributes, Check, CheckError}; /// Returns [`None`] for recursive checks. -pub(crate) fn check_type( +pub(crate) fn check_type( types: &mut TypeSystem, lhs: TypeId, rhs: TypeId, - visited: &mut HashSet<(TypeId, TypeId), S>, -) -> Result -where - S: BuildHasher, -{ + visited: &mut HashSet<(TypeId, TypeId)>, +) -> Result { if !visited.insert((lhs, rhs)) { if types.compare(lhs, rhs) <= Comparison::Castable { return Ok(Check::None); @@ -254,16 +248,13 @@ where Ok(check) } -fn check_union_against_rhs( +fn check_union_against_rhs( types: &mut TypeSystem, original_type_id: TypeId, items: &[TypeId], rhs: TypeId, - visited: &mut HashSet<(TypeId, TypeId), S>, -) -> Result -where - S: BuildHasher, -{ + visited: &mut HashSet<(TypeId, TypeId)>, +) -> Result { let union = types.alloc(Type::Union(items.to_vec())); if types.compare(union, rhs) <= Comparison::Castable { @@ -288,71 +279,7 @@ where return Ok(Check::Or(result)); } - let mut atom_count = 0; - let mut bytes32_count = 0; - let mut public_key_count = 0; - let mut pairs = Vec::new(); - let mut values = HashMap::new(); - - let mut items: VecDeque<_> = items.iter().copied().collect::>(); - let mut length = 0; - - while !items.is_empty() { - let item = items.remove(0).unwrap(); - length += 1; - - if !visited.insert((item, rhs)) { - return Err(CheckError::Recursive(item, rhs)); - } - - match types.get(item) { - Type::Ref(..) => unreachable!(), - Type::Lazy(..) => unreachable!(), - Type::Alias(..) => unreachable!(), - Type::Struct(..) => unreachable!(), - Type::Callable(..) => unreachable!(), - Type::Generic => return Err(CheckError::Impossible(item, rhs)), - Type::Unknown => {} - Type::Never => { - length -= 1; - } - Type::Atom | Type::Bytes | Type::Int => { - atom_count += 1; - } - Type::Bytes32 => { - atom_count += 1; - bytes32_count += 1; - } - Type::PublicKey => { - atom_count += 1; - public_key_count += 1; - } - Type::Nil => { - atom_count += 1; - *values.entry(BigInt::ZERO).or_insert(0) += 1; - } - Type::True => { - atom_count += 1; - *values.entry(BigInt::one()).or_insert(0) += 1; - } - Type::False => { - atom_count += 1; - *values.entry(BigInt::ZERO).or_insert(0) += 1; - } - Type::Value(value) => { - atom_count += 1; - *values.entry(value.clone()).or_insert(0) += 1; - } - Type::Pair(first, rest) => { - pairs.push((*first, *rest)); - } - Type::Union(child_items) => { - items.extend(child_items); - } - } - - visited.remove(&(item, rhs)); - } + let attrs = union_attributes(types, items, true, rhs, visited)?; Ok(match types.get(rhs) { Type::Ref(..) => unreachable!(), @@ -364,57 +291,51 @@ where Type::Unknown => Check::None, Type::Generic => return Err(CheckError::Impossible(original_type_id, rhs)), Type::Never => return Err(CheckError::Impossible(original_type_id, rhs)), - Type::Atom if atom_count == length => Check::None, + Type::Atom if attrs.all_atoms() => Check::None, Type::Atom => Check::IsAtom, - Type::Bytes if atom_count == length => Check::None, + Type::Bytes if attrs.all_atoms() => Check::None, Type::Bytes => Check::IsAtom, - Type::Int if atom_count == length => Check::None, + Type::Int if attrs.all_atoms() => Check::None, Type::Int => Check::IsAtom, - Type::Nil if values.get(&BigInt::ZERO).copied().unwrap_or(0) == length => Check::None, - Type::Nil if values.get(&BigInt::ZERO).copied().unwrap_or(0) == atom_count => Check::IsAtom, - Type::Nil if atom_count == length => Check::Value(BigInt::ZERO), + Type::Nil if attrs.all_value(&BigInt::ZERO) => Check::None, + Type::Nil if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom, + Type::Nil if attrs.all_atoms() => Check::Value(BigInt::ZERO), Type::Nil => Check::And(vec![Check::IsAtom, Check::Value(BigInt::ZERO)]), - Type::False if values.get(&BigInt::ZERO).copied().unwrap_or(0) == length => Check::None, - Type::False if values.get(&BigInt::ZERO).copied().unwrap_or(0) == atom_count => { - Check::IsAtom - } - Type::False if atom_count == length => Check::Value(BigInt::ZERO), + Type::False if attrs.all_value(&BigInt::ZERO) => Check::None, + Type::False if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom, + Type::False if attrs.all_atoms() => Check::Value(BigInt::ZERO), Type::False => Check::And(vec![Check::IsAtom, Check::Value(BigInt::ZERO)]), - Type::True if values.get(&BigInt::one()).copied().unwrap_or(0) == length => Check::None, - Type::True if values.get(&BigInt::one()).copied().unwrap_or(0) == atom_count => { - Check::IsAtom - } - Type::True if atom_count == length => Check::Value(BigInt::one()), + Type::True if attrs.all_value(&BigInt::one()) => Check::None, + Type::True if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom, + Type::True if attrs.all_atoms() => Check::Value(BigInt::one()), Type::True => Check::And(vec![Check::IsAtom, Check::Value(BigInt::one())]), - Type::Value(value) if values.get(value).copied().unwrap_or(0) == length => Check::None, - Type::Value(value) if values.get(value).copied().unwrap_or(0) == atom_count => { - Check::IsAtom - } - Type::Value(value) if atom_count == length => Check::Value(value.clone()), + Type::Value(value) if attrs.all_value(value) => Check::None, + Type::Value(value) if attrs.atoms_are_value(value) => Check::IsAtom, + Type::Value(value) if attrs.all_atoms() => Check::Value(value.clone()), Type::Value(value) => Check::And(vec![Check::IsAtom, Check::Value(value.clone())]), - Type::Bytes32 if bytes32_count == length => Check::None, - Type::Bytes32 if atom_count == length => Check::Length(32), - Type::Bytes32 if bytes32_count == atom_count => Check::IsAtom, + Type::Bytes32 if attrs.all_bytes32() => Check::None, + Type::Bytes32 if attrs.all_atoms() => Check::Length(32), + Type::Bytes32 if attrs.atoms_are_bytes32() => Check::IsAtom, Type::Bytes32 => Check::And(vec![Check::IsAtom, Check::Length(32)]), - Type::PublicKey if public_key_count == length => Check::None, - Type::PublicKey if atom_count == length => Check::Length(48), - Type::PublicKey if public_key_count == atom_count => Check::IsAtom, + Type::PublicKey if attrs.all_public_key() => Check::None, + Type::PublicKey if attrs.all_atoms() => Check::Length(48), + Type::PublicKey if attrs.atoms_are_public_key() => Check::IsAtom, Type::PublicKey => Check::And(vec![Check::IsAtom, Check::Length(48)]), - Type::Pair(..) if atom_count == length => { + Type::Pair(..) if attrs.all_atoms() => { return Err(CheckError::Impossible(original_type_id, rhs)) } Type::Pair(first, rest) => { let (first, rest) = (*first, *rest); - let first_items: Vec<_> = pairs.iter().map(|(first, _)| *first).collect(); - let rest_items: Vec<_> = pairs.iter().map(|(_, rest)| *rest).collect(); + let first_items: Vec<_> = attrs.pairs.iter().map(|(first, _)| *first).collect(); + let rest_items: Vec<_> = attrs.pairs.iter().map(|(_, rest)| *rest).collect(); let first = check_union_against_rhs(types, original_type_id, &first_items, first, visited)?; let rest = check_union_against_rhs(types, original_type_id, &rest_items, rest, visited)?; - if pairs.len() == length { + if attrs.all_pairs() { Check::And(vec![ Check::First(Box::new(first)), Check::Rest(Box::new(rest)),