Skip to content

Commit

Permalink
fix: Mutability in the comptime interpreter (#5517)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5504

## Summary\*

Mutable variables weren't wrapped in a mutable reference before like
they are in SSA - they are now.

## Additional Context

This is an alternate version from the last PR - I've cherry-picked the
last PR so this can be merged into master instead of another branch.

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <[email protected]>
  • Loading branch information
jfecher and TomAFrench authored Jul 15, 2024
1 parent 029584b commit 8cab4ac
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
41 changes: 34 additions & 7 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ impl<'a> Interpreter<'a> {
Ok(())
}
HirPattern::Mutable(pattern, _) => {
// Create a mutable reference to store to
let argument = Value::Pointer(Shared::new(argument), true);
self.define_pattern(pattern, typ, argument, location)
}
HirPattern::Tuple(pattern_fields, _) => match (argument, typ) {
Expand Down Expand Up @@ -334,8 +336,19 @@ impl<'a> Interpreter<'a> {
}
}

/// Evaluate an expression and return the result
/// Evaluate an expression and return the result.
/// This will automatically dereference a mutable variable if used.
pub fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.evaluate_no_dereference(id)? {
Value::Pointer(elem, true) => Ok(elem.borrow().clone()),
other => Ok(other),
}
}

/// Evaluating a mutable variable will dereference it automatically.
/// This function should be used when that is not desired - e.g. when
/// compiling a `&mut var` expression to grab the original reference.
fn evaluate_no_dereference(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
HirExpression::Ident(ident, _) => self.evaluate_ident(ident, id),
HirExpression::Literal(literal) => self.evaluate_literal(literal, id),
Expand Down Expand Up @@ -592,7 +605,10 @@ impl<'a> Interpreter<'a> {
}

fn evaluate_prefix(&mut self, prefix: HirPrefixExpression, id: ExprId) -> IResult<Value> {
let rhs = self.evaluate(prefix.rhs)?;
let rhs = match prefix.operator {
UnaryOp::MutableReference => self.evaluate_no_dereference(prefix.rhs)?,
_ => self.evaluate(prefix.rhs)?,
};
self.evaluate_prefix_with_value(rhs, prefix.operator, id)
}

Expand Down Expand Up @@ -634,9 +650,17 @@ impl<'a> Interpreter<'a> {
Err(InterpreterError::InvalidValueForUnary { value, location, operator: "not" })
}
},
UnaryOp::MutableReference => Ok(Value::Pointer(Shared::new(rhs))),
UnaryOp::MutableReference => {
// If this is a mutable variable (auto_deref = true), turn this into an explicit
// mutable reference just by switching the value of `auto_deref`. Otherwise, wrap
// the value in a fresh reference.
match rhs {
Value::Pointer(elem, true) => Ok(Value::Pointer(elem, false)),
other => Ok(Value::Pointer(Shared::new(other), false)),
}
}
UnaryOp::Dereference { implicitly_added: _ } => match rhs {
Value::Pointer(element) => Ok(element.borrow().clone()),
Value::Pointer(element, _) => Ok(element.borrow().clone()),
value => {
let location = self.interner.expr_location(&id);
Err(InterpreterError::NonPointerDereferenced { value, location })
Expand Down Expand Up @@ -1303,7 +1327,7 @@ impl<'a> Interpreter<'a> {
HirLValue::Ident(ident, typ) => self.mutate(ident.id, rhs, ident.location),
HirLValue::Dereference { lvalue, element_type: _, location } => {
match self.evaluate_lvalue(&lvalue)? {
Value::Pointer(value) => {
Value::Pointer(value, _) => {
*value.borrow_mut() = rhs;
Ok(())
}
Expand Down Expand Up @@ -1353,10 +1377,13 @@ impl<'a> Interpreter<'a> {

fn evaluate_lvalue(&mut self, lvalue: &HirLValue) -> IResult<Value> {
match lvalue {
HirLValue::Ident(ident, _) => self.lookup(ident),
HirLValue::Ident(ident, _) => match self.lookup(ident)? {
Value::Pointer(elem, true) => Ok(elem.borrow().clone()),
other => Ok(other),
},
HirLValue::Dereference { lvalue, element_type: _, location } => {
match self.evaluate_lvalue(lvalue)? {
Value::Pointer(value) => Ok(value.borrow().clone()),
Value::Pointer(value, _) => Ok(value.borrow().clone()),
value => {
Err(InterpreterError::NonPointerDereferenced { value, location: *location })
}
Expand Down
12 changes: 12 additions & 0 deletions compiler/noirc_frontend/src/hir/comptime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ fn mutating_mutable_references() {
assert_eq!(result, Value::I64(4));
}

#[test]
fn mutation_leaks() {
let program = "comptime fn main() -> pub i8 {
let mut x = 3;
let y = &mut x;
*y = 5;
x
}";
let result = interpret(program, vec!["main".into()]);
assert_eq!(result, Value::I8(5));
}

#[test]
fn mutating_arrays() {
let program = "comptime fn main() -> pub u8 {
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub enum Value {
Closure(HirLambda, Vec<Value>, Type),
Tuple(Vec<Value>),
Struct(HashMap<Rc<String>, Value>, Type),
Pointer(Shared<Value>),
Pointer(Shared<Value>, /* auto_deref */ bool),
Array(Vector<Value>, Type),
Slice(Vector<Value>, Type),
Code(Rc<Tokens>),
Expand Down Expand Up @@ -79,7 +79,7 @@ impl Value {
Value::Slice(_, typ) => return Cow::Borrowed(typ),
Value::Code(_) => Type::Quoted(QuotedType::Quoted),
Value::StructDefinition(_) => Type::Quoted(QuotedType::StructDefinition),
Value::Pointer(element) => {
Value::Pointer(element, _) => {
let element = element.borrow().get_type().into_owned();
Type::MutableReference(Box::new(element))
}
Expand Down Expand Up @@ -199,7 +199,7 @@ impl Value {
}
};
}
Value::Pointer(_)
Value::Pointer(..)
| Value::StructDefinition(_)
| Value::TraitDefinition(_)
| Value::FunctionDefinition(_)
Expand Down Expand Up @@ -309,7 +309,7 @@ impl Value {
HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements)))
}
Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)),
Value::Pointer(_)
Value::Pointer(..)
| Value::StructDefinition(_)
| Value::TraitDefinition(_)
| Value::FunctionDefinition(_)
Expand Down Expand Up @@ -400,7 +400,7 @@ impl Display for Value {
let fields = vecmap(fields, |(name, value)| format!("{}: {}", name, value));
write!(f, "{typename} {{ {} }}", fields.join(", "))
}
Value::Pointer(value) => write!(f, "&mut {}", value.borrow()),
Value::Pointer(value, _) => write!(f, "&mut {}", value.borrow()),
Value::Array(values, _) => {
let values = vecmap(values, ToString::to_string);
write!(f, "[{}]", values.join(", "))
Expand Down

0 comments on commit 8cab4ac

Please sign in to comment.