Skip to content

Commit

Permalink
Refactor attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jul 23, 2024
1 parent d189bf1 commit 606426a
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 113 deletions.
2 changes: 2 additions & 0 deletions crates/rue-typing/src/check.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt;

mod attributes;
mod check_error;
mod check_type;
mod simplify_and;
Expand All @@ -9,6 +10,7 @@ mod stringify_check;

pub use check_error::*;

pub(crate) use attributes::*;
pub(crate) use check_type::*;
pub(crate) use simplify_and::*;
pub(crate) use simplify_check::*;
Expand Down
140 changes: 140 additions & 0 deletions crates/rue-typing/src/check/attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
use std::collections::{HashMap, HashSet, VecDeque};

use num_bigint::BigInt;
use num_traits::One;

use crate::{Type, TypeId, TypeSystem};

use super::CheckError;

pub(crate) struct Attributes {
pub atom_count: usize,
pub bytes32_count: usize,
pub public_key_count: usize,
pub pairs: Vec<(TypeId, TypeId)>,
pub values: HashMap<BigInt, usize>,
pub length: usize,
}

impl Attributes {
pub fn all_atoms(&self) -> bool {
self.atom_count == self.length
}

pub fn all_bytes32(&self) -> bool {
self.bytes32_count == self.length
}

pub fn all_public_key(&self) -> bool {
self.public_key_count == self.length
}

pub fn all_pairs(&self) -> bool {
self.pairs.len() == self.length
}

pub fn all_value(&self, value: &BigInt) -> bool {
self.values.get(value).copied().unwrap_or(0) == self.length
}

pub fn atoms_are_bytes32(&self) -> bool {
self.bytes32_count == self.atom_count
}

pub fn atoms_are_public_key(&self) -> bool {
self.public_key_count == self.atom_count
}

pub fn atoms_are_value(&self, value: &BigInt) -> bool {
self.values.get(value).copied().unwrap_or(0) == self.atom_count
}
}

pub(crate) fn union_attributes(
db: &TypeSystem,
items: &[TypeId],
is_lhs: bool,
other_type_id: TypeId,
visited: &mut HashSet<(TypeId, TypeId)>,
) -> Result<Attributes, CheckError> {
let mut atom_count = 0;
let mut bytes32_count = 0;
let mut public_key_count = 0;
let mut pairs = Vec::new();
let mut values = HashMap::new();

let mut items: VecDeque<_> = items.iter().copied().collect();
let mut length = 0;

while !items.is_empty() {
let item = items.remove(0).unwrap();
length += 1;

let key = if is_lhs {
(item, other_type_id)
} else {
(other_type_id, item)
};

if !visited.insert(key) {
return Err(CheckError::Recursive(key.0, key.1));
}

match db.get(item) {
Type::Ref(..) => unreachable!(),
Type::Lazy(..) => unreachable!(),
Type::Alias(..) => unreachable!(),
Type::Struct(..) => unreachable!(),
Type::Callable(..) => unreachable!(),
Type::Generic => return Err(CheckError::Impossible(key.0, key.1)),
Type::Unknown => {}
Type::Never => {
length -= 1;
}
Type::Atom | Type::Bytes | Type::Int => {
atom_count += 1;
}
Type::Bytes32 => {
atom_count += 1;
bytes32_count += 1;
}
Type::PublicKey => {
atom_count += 1;
public_key_count += 1;
}
Type::Nil => {
atom_count += 1;
*values.entry(BigInt::ZERO).or_insert(0) += 1;
}
Type::True => {
atom_count += 1;
*values.entry(BigInt::one()).or_insert(0) += 1;
}
Type::False => {
atom_count += 1;
*values.entry(BigInt::ZERO).or_insert(0) += 1;
}
Type::Value(value) => {
atom_count += 1;
*values.entry(value.clone()).or_insert(0) += 1;
}
Type::Pair(first, rest) => {
pairs.push((*first, *rest));
}
Type::Union(child_items) => {
items.extend(child_items);
}
}

visited.remove(&key);
}

Ok(Attributes {
atom_count,
bytes32_count,
public_key_count,
pairs,
values,
length,
})
}
147 changes: 34 additions & 113 deletions crates/rue-typing/src/check/check_type.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
use std::{
collections::{HashMap, HashSet, VecDeque},
hash::BuildHasher,
};
use std::collections::HashSet;

use num_bigint::BigInt;
use num_traits::One;

use crate::{bigint_to_bytes, Comparison, Type, TypeId, TypeSystem};

use super::{Check, CheckError};
use super::{union_attributes, Check, CheckError};

/// Returns [`None`] for recursive checks.
pub(crate) fn check_type<S>(
pub(crate) fn check_type(
types: &mut TypeSystem,
lhs: TypeId,
rhs: TypeId,
visited: &mut HashSet<(TypeId, TypeId), S>,
) -> Result<Check, CheckError>
where
S: BuildHasher,
{
visited: &mut HashSet<(TypeId, TypeId)>,
) -> Result<Check, CheckError> {
if !visited.insert((lhs, rhs)) {
if types.compare(lhs, rhs) <= Comparison::Castable {
return Ok(Check::None);
Expand Down Expand Up @@ -254,16 +248,13 @@ where
Ok(check)
}

fn check_union_against_rhs<S>(
fn check_union_against_rhs(
types: &mut TypeSystem,
original_type_id: TypeId,
items: &[TypeId],
rhs: TypeId,
visited: &mut HashSet<(TypeId, TypeId), S>,
) -> Result<Check, CheckError>
where
S: BuildHasher,
{
visited: &mut HashSet<(TypeId, TypeId)>,
) -> Result<Check, CheckError> {
let union = types.alloc(Type::Union(items.to_vec()));

if types.compare(union, rhs) <= Comparison::Castable {
Expand All @@ -288,71 +279,7 @@ where
return Ok(Check::Or(result));
}

let mut atom_count = 0;
let mut bytes32_count = 0;
let mut public_key_count = 0;
let mut pairs = Vec::new();
let mut values = HashMap::new();

let mut items: VecDeque<_> = items.iter().copied().collect::<VecDeque<_>>();
let mut length = 0;

while !items.is_empty() {
let item = items.remove(0).unwrap();
length += 1;

if !visited.insert((item, rhs)) {
return Err(CheckError::Recursive(item, rhs));
}

match types.get(item) {
Type::Ref(..) => unreachable!(),
Type::Lazy(..) => unreachable!(),
Type::Alias(..) => unreachable!(),
Type::Struct(..) => unreachable!(),
Type::Callable(..) => unreachable!(),
Type::Generic => return Err(CheckError::Impossible(item, rhs)),
Type::Unknown => {}
Type::Never => {
length -= 1;
}
Type::Atom | Type::Bytes | Type::Int => {
atom_count += 1;
}
Type::Bytes32 => {
atom_count += 1;
bytes32_count += 1;
}
Type::PublicKey => {
atom_count += 1;
public_key_count += 1;
}
Type::Nil => {
atom_count += 1;
*values.entry(BigInt::ZERO).or_insert(0) += 1;
}
Type::True => {
atom_count += 1;
*values.entry(BigInt::one()).or_insert(0) += 1;
}
Type::False => {
atom_count += 1;
*values.entry(BigInt::ZERO).or_insert(0) += 1;
}
Type::Value(value) => {
atom_count += 1;
*values.entry(value.clone()).or_insert(0) += 1;
}
Type::Pair(first, rest) => {
pairs.push((*first, *rest));
}
Type::Union(child_items) => {
items.extend(child_items);
}
}

visited.remove(&(item, rhs));
}
let attrs = union_attributes(types, items, true, rhs, visited)?;

Ok(match types.get(rhs) {
Type::Ref(..) => unreachable!(),
Expand All @@ -364,57 +291,51 @@ where
Type::Unknown => Check::None,
Type::Generic => return Err(CheckError::Impossible(original_type_id, rhs)),
Type::Never => return Err(CheckError::Impossible(original_type_id, rhs)),
Type::Atom if atom_count == length => Check::None,
Type::Atom if attrs.all_atoms() => Check::None,
Type::Atom => Check::IsAtom,
Type::Bytes if atom_count == length => Check::None,
Type::Bytes if attrs.all_atoms() => Check::None,
Type::Bytes => Check::IsAtom,
Type::Int if atom_count == length => Check::None,
Type::Int if attrs.all_atoms() => Check::None,
Type::Int => Check::IsAtom,
Type::Nil if values.get(&BigInt::ZERO).copied().unwrap_or(0) == length => Check::None,
Type::Nil if values.get(&BigInt::ZERO).copied().unwrap_or(0) == atom_count => Check::IsAtom,
Type::Nil if atom_count == length => Check::Value(BigInt::ZERO),
Type::Nil if attrs.all_value(&BigInt::ZERO) => Check::None,
Type::Nil if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom,
Type::Nil if attrs.all_atoms() => Check::Value(BigInt::ZERO),
Type::Nil => Check::And(vec![Check::IsAtom, Check::Value(BigInt::ZERO)]),
Type::False if values.get(&BigInt::ZERO).copied().unwrap_or(0) == length => Check::None,
Type::False if values.get(&BigInt::ZERO).copied().unwrap_or(0) == atom_count => {
Check::IsAtom
}
Type::False if atom_count == length => Check::Value(BigInt::ZERO),
Type::False if attrs.all_value(&BigInt::ZERO) => Check::None,
Type::False if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom,
Type::False if attrs.all_atoms() => Check::Value(BigInt::ZERO),
Type::False => Check::And(vec![Check::IsAtom, Check::Value(BigInt::ZERO)]),
Type::True if values.get(&BigInt::one()).copied().unwrap_or(0) == length => Check::None,
Type::True if values.get(&BigInt::one()).copied().unwrap_or(0) == atom_count => {
Check::IsAtom
}
Type::True if atom_count == length => Check::Value(BigInt::one()),
Type::True if attrs.all_value(&BigInt::one()) => Check::None,
Type::True if attrs.atoms_are_value(&BigInt::ZERO) => Check::IsAtom,
Type::True if attrs.all_atoms() => Check::Value(BigInt::one()),
Type::True => Check::And(vec![Check::IsAtom, Check::Value(BigInt::one())]),
Type::Value(value) if values.get(value).copied().unwrap_or(0) == length => Check::None,
Type::Value(value) if values.get(value).copied().unwrap_or(0) == atom_count => {
Check::IsAtom
}
Type::Value(value) if atom_count == length => Check::Value(value.clone()),
Type::Value(value) if attrs.all_value(value) => Check::None,
Type::Value(value) if attrs.atoms_are_value(value) => Check::IsAtom,
Type::Value(value) if attrs.all_atoms() => Check::Value(value.clone()),
Type::Value(value) => Check::And(vec![Check::IsAtom, Check::Value(value.clone())]),
Type::Bytes32 if bytes32_count == length => Check::None,
Type::Bytes32 if atom_count == length => Check::Length(32),
Type::Bytes32 if bytes32_count == atom_count => Check::IsAtom,
Type::Bytes32 if attrs.all_bytes32() => Check::None,
Type::Bytes32 if attrs.all_atoms() => Check::Length(32),
Type::Bytes32 if attrs.atoms_are_bytes32() => Check::IsAtom,
Type::Bytes32 => Check::And(vec![Check::IsAtom, Check::Length(32)]),
Type::PublicKey if public_key_count == length => Check::None,
Type::PublicKey if atom_count == length => Check::Length(48),
Type::PublicKey if public_key_count == atom_count => Check::IsAtom,
Type::PublicKey if attrs.all_public_key() => Check::None,
Type::PublicKey if attrs.all_atoms() => Check::Length(48),
Type::PublicKey if attrs.atoms_are_public_key() => Check::IsAtom,
Type::PublicKey => Check::And(vec![Check::IsAtom, Check::Length(48)]),
Type::Pair(..) if atom_count == length => {
Type::Pair(..) if attrs.all_atoms() => {
return Err(CheckError::Impossible(original_type_id, rhs))
}
Type::Pair(first, rest) => {
let (first, rest) = (*first, *rest);

let first_items: Vec<_> = pairs.iter().map(|(first, _)| *first).collect();
let rest_items: Vec<_> = pairs.iter().map(|(_, rest)| *rest).collect();
let first_items: Vec<_> = attrs.pairs.iter().map(|(first, _)| *first).collect();
let rest_items: Vec<_> = attrs.pairs.iter().map(|(_, rest)| *rest).collect();

let first =
check_union_against_rhs(types, original_type_id, &first_items, first, visited)?;
let rest =
check_union_against_rhs(types, original_type_id, &rest_items, rest, visited)?;

if pairs.len() == length {
if attrs.all_pairs() {
Check::And(vec![
Check::First(Box::new(first)),
Check::Rest(Box::new(rest)),
Expand Down

0 comments on commit 606426a

Please sign in to comment.