From 3c5d77ee03130a54fb13fff24306ae6bf3b5f568 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Wed, 2 Oct 2024 10:29:40 +0200 Subject: [PATCH 1/8] wip: datalog foreign function interface prototype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows using external functions in datalog. This makes it easy to provide custom logic without extending the spec for every use-case, at the expense of portability: behaviour is no longer guaranteed to be consistent cross languages, and some languages won’t be able to support it at all (for instance JS as of now). Todo: - stricter conversions from datalog - feature-gating if possible Open questions: - enum index for the FFI variants (contiguous or not?) - how to provide functions (right now, function pointers: prevent mutability and closing over arguments) - how to provide arguments (right now, datalog::Term, so symbols have to be resolved, and functions returning strings have to register new symbols) --- biscuit-auth/src/datalog/expression.rs | 248 +++++++++++++----- biscuit-auth/src/datalog/mod.rs | 62 +++-- biscuit-auth/src/error.rs | 6 + biscuit-auth/src/format/convert.rs | 20 ++ biscuit-auth/src/format/schema.proto | 4 + biscuit-auth/src/format/schema.rs | 6 + biscuit-auth/src/parser.rs | 12 +- biscuit-auth/src/token/authorizer.rs | 41 ++- biscuit-auth/src/token/authorizer/snapshot.rs | 1 + biscuit-auth/src/token/builder.rs | 2 + biscuit-auth/tests/macros.rs | 8 + biscuit-capi/src/lib.rs | 2 + biscuit-capi/tests/capi.rs | 2 +- biscuit-parser/src/builder.rs | 4 + biscuit-parser/src/parser.rs | 37 +++ 15 files changed, 366 insertions(+), 89 deletions(-) diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index 67d9e6f8..48ba7772 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -6,6 +6,16 @@ use regex::Regex; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +type ExternBinary = fn(&mut TemporarySymbolTable, &Term, &Term) -> Result; + +type ExternUnary = fn(&mut TemporarySymbolTable, &Term) -> Result; + +#[derive(Debug, Clone)] +pub enum ExternFunc { + Unary(ExternUnary), + Binary(ExternBinary), +} + #[derive(Debug, Clone, PartialEq, Hash, Eq)] pub struct Expression { pub ops: Vec, @@ -26,6 +36,7 @@ pub enum Unary { Parens, Length, TypeOf, + Ffi(String), } impl Unary { @@ -33,6 +44,7 @@ impl Unary { &self, value: Term, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { match (self, value) { (Unary::Negate, Term::Bool(b)) => Ok(Term::Bool(!b)), @@ -61,6 +73,15 @@ impl Unary { let sym = symbols.insert(type_string); Ok(Term::Str(sym)) } + (Unary::Ffi(name), i) => { + let fun = extern_funcs + .get(name) + .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; + match fun { + ExternFunc::Unary(fun) => fun(symbols, &i), + ExternFunc::Binary(_) => Err(error::Expression::IncorrectArityExtern), + } + } _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -74,6 +95,7 @@ impl Unary { Unary::Parens => format!("({})", value), Unary::Length => format!("{}.length()", value), Unary::TypeOf => format!("{}.type()", value), + Unary::Ffi(name) => format!("{value}.extern::{name}()"), } } } @@ -109,6 +131,7 @@ pub enum Binary { All, Any, Get, + Ffi(String), } impl Binary { @@ -119,18 +142,19 @@ impl Binary { params: &[u32], values: &mut HashMap, symbols: &mut TemporarySymbolTable, + extern_func: &HashMap, ) -> Result { match (self, left, params) { // boolean (Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)), (Binary::LazyOr, Term::Bool(false), []) => { let e = Expression { ops: right.clone() }; - e.evaluate(values, symbols) + e.evaluate(values, symbols, extern_func) } (Binary::LazyAnd, Term::Bool(false), []) => Ok(Term::Bool(false)), (Binary::LazyAnd, Term::Bool(true), []) => { let e = Expression { ops: right.clone() }; - e.evaluate(values, symbols) + e.evaluate(values, symbols, extern_func) } // set @@ -138,7 +162,7 @@ impl Binary { for value in set_values.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -152,7 +176,7 @@ impl Binary { for value in set_values.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -168,7 +192,7 @@ impl Binary { for value in array.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -182,7 +206,7 @@ impl Binary { for value in array.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -203,7 +227,7 @@ impl Binary { values.insert(*param, Term::Array(vec![key, value.clone()])); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -222,7 +246,7 @@ impl Binary { values.insert(*param, Term::Array(vec![key, value.clone()])); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -240,6 +264,7 @@ impl Binary { left: Term, right: Term, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { match (self, left, right) { // integer @@ -438,9 +463,21 @@ impl Binary { None => Ok(Term::Null), }, + // heterogeneous equals catch all (Binary::HeterogeneousEqual, _, _) => Ok(Term::Bool(false)), (Binary::HeterogeneousNotEqual, _, _) => Ok(Term::Bool(true)), + // FFI + (Binary::Ffi(name), left, right) => { + let fun = extern_funcs + .get(name) + .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; + match fun { + ExternFunc::Binary(fun) => fun(symbols, &left, &right), + ExternFunc::Unary(_) => Err(error::Expression::IncorrectArityExtern), + } + } + _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -478,6 +515,7 @@ impl Binary { Binary::All => format!("{left}.all({right})"), Binary::Any => format!("{left}.any({right})"), Binary::Get => format!("{left}.get({right})"), + Binary::Ffi(name) => format!("{left}.extern::{name}({right})"), } } } @@ -493,6 +531,7 @@ impl Expression { &self, values: &HashMap, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { let mut stack: Vec = Vec::new(); @@ -508,19 +547,24 @@ impl Expression { } }, Op::Value(term) => stack.push(StackElem::Term(term.clone())), - Op::Unary(unary) => match stack.pop() { - Some(StackElem::Term(term)) => { - stack.push(StackElem::Term(unary.evaluate(term, symbols)?)) - } - _ => { - return Err(error::Expression::InvalidStack); + Op::Unary(unary) => { + match stack.pop() { + Some(StackElem::Term(term)) => stack.push(StackElem::Term( + unary.evaluate(term, symbols, extern_funcs)?, + )), + _ => { + return Err(error::Expression::InvalidStack); + } } - }, + } Op::Binary(binary) => match (stack.pop(), stack.pop()) { (Some(StackElem::Term(right_term)), Some(StackElem::Term(left_term))) => stack - .push(StackElem::Term( - binary.evaluate(left_term, right_term, symbols)?, - )), + .push(StackElem::Term(binary.evaluate( + left_term, + right_term, + symbols, + extern_funcs, + )?)), ( Some(StackElem::Closure(params, right_ops)), Some(StackElem::Term(left_term)), @@ -541,6 +585,7 @@ impl Expression { ¶ms, &mut values, symbols, + extern_funcs, )?)) } @@ -638,7 +683,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); } @@ -668,7 +713,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&HashMap::new(), &mut tmp_symbols); + let res = e.evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(expected))); } } @@ -685,7 +730,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::DivideByZero)); let ops = vec![ @@ -696,7 +741,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); let ops = vec![ @@ -707,7 +752,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); let ops = vec![ @@ -718,7 +763,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); } @@ -785,7 +830,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); } } @@ -809,7 +854,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); } } @@ -833,7 +878,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(result))); } } @@ -877,7 +922,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(*result))); } } @@ -916,7 +961,8 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - e.evaluate(&values, &mut tmp_symbols).unwrap_err(); + e.evaluate(&values, &mut tmp_symbols, &Default::default()) + .unwrap_err(); } } } @@ -941,7 +987,9 @@ mod tests { ]; let e2 = Expression { ops: ops1 }; - let res2 = e2.evaluate(&HashMap::new(), &mut symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(true)); } @@ -959,7 +1007,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); let ops2 = vec![ @@ -977,7 +1027,9 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{:?}", e2.print(&symbols)); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(false)); let ops3 = vec![ @@ -988,7 +1040,9 @@ mod tests { let e3 = Expression { ops: ops3 }; println!("{:?}", e3.print(&symbols)); - let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + let err3 = e3 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap_err(); assert_eq!(err3, error::Expression::InvalidType); } @@ -1013,7 +1067,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); let ops2 = vec![ @@ -1031,7 +1087,9 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{:?}", e2.print(&symbols)); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(false)); let ops3 = vec![ @@ -1042,7 +1100,9 @@ mod tests { let e3 = Expression { ops: ops3 }; println!("{:?}", e3.print(&symbols)); - let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + let err3 = e3 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap_err(); assert_eq!(err3, error::Expression::InvalidType); } @@ -1088,7 +1148,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{}", e1.print(&symbols).unwrap()); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); } @@ -1115,7 +1177,7 @@ mod tests { let mut values = HashMap::new(); values.insert(p, Term::Null); - let res1 = e1.evaluate(&values, &mut tmp_symbols); + let res1 = e1.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res1, Err(error::Expression::ShadowedVariable)); let mut symbols = SymbolTable::new(); @@ -1157,7 +1219,7 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{}", e2.print(&symbols).unwrap()); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols); + let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()); assert_eq!(res2, Err(error::Expression::ShadowedVariable)); } @@ -1173,7 +1235,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1184,7 +1246,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1195,7 +1257,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1206,7 +1268,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1221,7 +1283,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1236,7 +1298,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1251,7 +1313,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1266,7 +1328,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); // get @@ -1282,7 +1344,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(1))); // get out of bounds @@ -1298,7 +1360,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Null)); // all @@ -1318,7 +1380,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); // any @@ -1337,7 +1401,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(false)); } @@ -1371,7 +1437,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1395,7 +1461,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1414,7 +1480,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1433,7 +1499,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); // get @@ -1453,7 +1519,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(0))); let ops = vec![ @@ -1472,7 +1538,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(1))); // get non existing key @@ -1492,7 +1558,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Null)); let ops = vec![ @@ -1511,7 +1577,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Null)); // all @@ -1540,7 +1606,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); // any @@ -1569,7 +1637,65 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); } + #[test] + fn ffi() { + let mut symbols = SymbolTable::new(); + let i = symbols.insert("test"); + let j = symbols.insert("TeSt"); + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + let ops = vec![ + Op::Value(Term::Integer(60)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Ffi("test_bin".to_owned())), + Op::Value(Term::Str(i)), + Op::Value(Term::Str(j)), + Op::Binary(Binary::Ffi("test_bin".to_owned())), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi("test_un".to_owned())), + Op::Binary(Binary::And), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let mut extern_funcs: HashMap = Default::default(); + extern_funcs.insert( + "test_bin".to_owned(), + ExternFunc::Binary(|sym, left, right| match (left, right) { + (Term::Integer(left), Term::Integer(right)) => { + println!("{left} {right}"); + Ok(Term::Bool((left % 60) == (right % 60))) + } + (Term::Str(left), Term::Str(right)) => { + let left = sym + .get_symbol(*left) + .ok_or(error::Expression::UnknownSymbol(*left))?; + let right = sym + .get_symbol(*right) + .ok_or(error::Expression::UnknownSymbol(*right))?; + + println!("{left} {right}"); + Ok(Term::Bool(left.to_lowercase() == right.to_lowercase())) + } + _ => Err(error::Expression::InvalidType), + }), + ); + extern_funcs.insert( + "test_un".to_owned(), + ExternFunc::Unary(|_, value| match value { + Term::Integer(value) => Ok(Term::Bool(*value == 42)), + _ => { + println!("{value:?}"); + Err(error::Expression::InvalidType) + } + }), + ); + let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs); + assert_eq!(res, Ok(Term::Bool(true))); + } } diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index f722f00c..b91ccdfe 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -138,6 +138,7 @@ impl Rule { facts: IT, rule_origin: usize, symbols: &'a SymbolTable, + extern_funcs: &'a HashMap, ) -> impl Iterator> + 'a where IT: Iterator + Clone + 'a, @@ -149,7 +150,7 @@ impl Rule { .map(move |(origin, variables)| { let mut temporary_symbols = TemporarySymbolTable::new(symbols); for e in self.expressions.iter() { - match e.evaluate(&variables, &mut temporary_symbols) { + match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) { Ok(Term::Bool(true)) => {} Ok(Term::Bool(false)) => return Ok((origin, variables, false)), Ok(_) => return Err(error::Expression::InvalidType), @@ -194,9 +195,10 @@ impl Rule { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let fact_it = facts.iterator(scope); - let mut it = self.apply(fact_it, origin, symbols); + let mut it = self.apply(fact_it, origin, symbols, extern_funcs); let next = it.next(); match next { @@ -211,6 +213,7 @@ impl Rule { facts: &FactSet, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let fact_it = facts.iterator(scope); let variables = MatchedVariables::new(self.variables_set()); @@ -221,7 +224,7 @@ impl Rule { let mut temporary_symbols = TemporarySymbolTable::new(symbols); for e in self.expressions.iter() { - match e.evaluate(&variables, &mut temporary_symbols) { + match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) { Ok(Term::Bool(true)) => {} Ok(Term::Bool(false)) => { //println!("expr returned {:?}", res); @@ -619,7 +622,7 @@ impl World { for (scope, rules) in self.rules.inner.iter() { let it = self.facts.iterator(scope); for (origin, rule) in rules { - for res in rule.apply(it.clone(), *origin, symbols) { + for res in rule.apply(it.clone(), *origin, symbols, &limits.extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -690,11 +693,12 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let mut new_facts = FactSet::default(); let it = self.facts.iterator(scope); //new_facts.extend(rule.apply(it, origin, symbols)); - for res in rule.apply(it.clone(), origin, symbols) { + for res in rule.apply(it.clone(), origin, symbols, extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -714,8 +718,9 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { - rule.find_match(&self.facts, origin, scope, symbols) + rule.find_match(&self.facts, origin, scope, symbols, extern_funcs) } pub fn query_match_all( @@ -723,8 +728,9 @@ impl World { rule: Rule, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { - rule.check_match_all(&self.facts, scope, symbols) + rule.check_match_all(&self.facts, scope, symbols, extern_funcs) } } @@ -737,6 +743,8 @@ pub struct RunLimits { pub max_iterations: u64, /// maximum execution time pub max_time: Duration, + + pub extern_funcs: HashMap, } impl std::default::Default for RunLimits { @@ -745,6 +753,7 @@ impl std::default::Default for RunLimits { max_facts: 1000, max_iterations: 100, max_time: Duration::from_millis(1), + extern_funcs: Default::default(), } } } @@ -1047,7 +1056,8 @@ mod tests { println!("symbols: {:?}", syms); println!("testing r1: {}", syms.print_rule(&r1)); - let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms); + let query_rule_result = + w.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()); println!("grandparents query_rules: {:?}", query_rule_result); println!("current facts: {:?}", w.facts); @@ -1092,6 +1102,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1109,7 +1120,8 @@ mod tests { ), 0, &[0].iter().collect(), - &syms + &syms, + &Default::default() ) ); println!( @@ -1125,7 +1137,8 @@ mod tests { ), 0, &[0].iter().collect(), - &syms + &syms, + &Default::default() ) ); w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e])); @@ -1143,6 +1156,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); println!("grandparents after inserting parent(C, E): {:?}", res); @@ -1218,6 +1232,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1267,6 +1282,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1353,6 +1369,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap() .iter_all() @@ -1433,7 +1450,9 @@ mod tests { ); println!("testing r1: {}", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1471,7 +1490,9 @@ mod tests { ); println!("testing r2: {}", syms.print_rule(&r2)); - let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1534,6 +1555,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1585,6 +1607,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1630,6 +1653,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1675,6 +1699,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1698,6 +1723,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1740,7 +1766,9 @@ mod tests { println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1779,7 +1807,9 @@ mod tests { ); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { @@ -1795,7 +1825,9 @@ mod tests { let r2 = rule(check, &[&read], &[pred(operation, &[&read])]); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r2: {}\n", syms.print_rule(&r2)); - let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 984369c5..5209703b 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -150,6 +150,8 @@ pub enum Format { UnknownExternalKey, #[error("the symbol id was not in the table")] UnknownSymbol(u64), + #[error("missing FFI name field")] + MissingFfiName, } /// Signature errors @@ -250,6 +252,10 @@ pub enum Expression { InvalidStack, #[error("Shadowed variable")] ShadowedVariable, + #[error("Incorrect arity for extern func")] + IncorrectArityExtern, + #[error("Undefined extern func: {0}")] + UndefinedExtern(String), } /// runtime limits errors diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index c829c7b5..ad8a0bab 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -668,7 +668,12 @@ pub mod v2 { Unary::Parens => Kind::Parens, Unary::Length => Kind::Length, Unary::TypeOf => Kind::TypeOf, + Unary::Ffi(_) => Kind::Ffi, } as i32, + ffi_name: match u { + Unary::Ffi(name) => Some(name.to_owned()), + _ => None, + }, }) } Op::Binary(b) => { @@ -704,7 +709,12 @@ pub mod v2 { Binary::All => Kind::All, Binary::Any => Kind::Any, Binary::Get => Kind::Get, + Binary::Ffi(_) => Kind::Ffi, } as i32, + ffi_name: match b { + Binary::Ffi(name) => Some(name.to_owned()), + _ => None, + }, }) } Op::Closure(params, ops) => schema::op::Content::Closure(schema::OpClosure { @@ -733,6 +743,11 @@ pub mod v2 { Some(op_unary::Kind::Parens) => Op::Unary(Unary::Parens), Some(op_unary::Kind::Length) => Op::Unary(Unary::Length), Some(op_unary::Kind::TypeOf) => Op::Unary(Unary::TypeOf), + Some(op_unary::Kind::Ffi) => match u.ffi_name.as_ref() { + // todo clementd error if ffi name is defined with another kind + Some(n) => Op::Unary(Unary::Ffi(n.to_owned())), + None => return Err(error::Format::MissingFfiName), + }, None => { return Err(error::Format::DeserializationError( "deserialization error: unary operation is empty".to_string(), @@ -770,6 +785,11 @@ pub mod v2 { Some(op_binary::Kind::All) => Op::Binary(Binary::All), Some(op_binary::Kind::Any) => Op::Binary(Binary::Any), Some(op_binary::Kind::Get) => Op::Binary(Binary::Get), + Some(op_binary::Kind::Ffi) => match b.ffi_name.as_ref() { + // todo clementd error if ffi name is defined with another kind + Some(n) => Op::Binary(Binary::Ffi(n.to_owned())), + None => return Err(error::Format::MissingFfiName), + }, None => { return Err(error::Format::DeserializationError( "deserialization error: binary operation is empty".to_string(), diff --git a/biscuit-auth/src/format/schema.proto b/biscuit-auth/src/format/schema.proto index 57ab8b7d..ad802596 100644 --- a/biscuit-auth/src/format/schema.proto +++ b/biscuit-auth/src/format/schema.proto @@ -148,9 +148,11 @@ message OpUnary { Parens = 1; Length = 2; TypeOf = 3; + Ffi = 4; } required Kind kind = 1; + optional string ffiName = 2; } message OpBinary { @@ -183,9 +185,11 @@ message OpBinary { All = 25; Any = 26; Get = 27; + Ffi = 28; } required Kind kind = 1; + optional string ffiName = 2; } message OpClosure { diff --git a/biscuit-auth/src/format/schema.rs b/biscuit-auth/src/format/schema.rs index c90b5e61..c8189ae0 100644 --- a/biscuit-auth/src/format/schema.rs +++ b/biscuit-auth/src/format/schema.rs @@ -235,6 +235,8 @@ pub mod op { pub struct OpUnary { #[prost(enumeration="op_unary::Kind", required, tag="1")] pub kind: i32, + #[prost(string, optional, tag="2")] + pub ffi_name: ::core::option::Option<::prost::alloc::string::String>, } /// Nested message and enum types in `OpUnary`. pub mod op_unary { @@ -245,12 +247,15 @@ pub mod op_unary { Parens = 1, Length = 2, TypeOf = 3, + Ffi = 4, } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct OpBinary { #[prost(enumeration="op_binary::Kind", required, tag="1")] pub kind: i32, + #[prost(string, optional, tag="2")] + pub ffi_name: ::core::option::Option<::prost::alloc::string::String>, } /// Nested message and enum types in `OpBinary`. pub mod op_binary { @@ -285,6 +290,7 @@ pub mod op_binary { All = 25, Any = 26, Get = 27, + Ffi = 28, } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/biscuit-auth/src/parser.rs b/biscuit-auth/src/parser.rs index 59510afb..77f0c3ef 100644 --- a/biscuit-auth/src/parser.rs +++ b/biscuit-auth/src/parser.rs @@ -383,7 +383,11 @@ mod tests { println!("print: {}", e.print(&syms).unwrap()); let h = HashMap::new(); let result = e - .evaluate(&h, &mut TemporarySymbolTable::new(&syms)) + .evaluate( + &h, + &mut TemporarySymbolTable::new(&syms), + &Default::default(), + ) .unwrap(); println!("evaluates to: {:?}", result); @@ -414,7 +418,11 @@ mod tests { println!("print: {}", e.print(&syms).unwrap()); let h = HashMap::new(); let result = e - .evaluate(&h, &mut TemporarySymbolTable::new(&syms)) + .evaluate( + &h, + &mut TemporarySymbolTable::new(&syms), + &Default::default(), + ) .unwrap(); println!("evaluates to: {:?}", result); diff --git a/biscuit-auth/src/token/authorizer.rs b/biscuit-auth/src/token/authorizer.rs index 1648601d..75777834 100644 --- a/biscuit-auth/src/token/authorizer.rs +++ b/biscuit-auth/src/token/authorizer.rs @@ -469,10 +469,15 @@ impl Authorizer { &self.public_key_to_block_id, ); + let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; - let res = self - .world - .query_rule(rule, usize::MAX, &rule_trusted_origins, &self.symbols)?; + let res = self.world.query_rule( + rule, + usize::MAX, + &rule_trusted_origins, + &self.symbols, + &extern_binary, + )?; res.inner .into_iter() @@ -552,6 +557,7 @@ impl Authorizer { rule: datalog::Rule, limits: AuthorizerLimits, ) -> Result, error::Token> { + let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; let rule_trusted_origins = if rule.scopes.is_empty() { @@ -568,9 +574,13 @@ impl Authorizer { ) }; - let res = self - .world - .query_rule(rule, 0, &rule_trusted_origins, &self.symbols)?; + let res = self.world.query_rule( + rule, + 0, + &rule_trusted_origins, + &self.symbols, + &extern_binary, + )?; let r: HashSet<_> = res.into_iter().map(|(_, fact)| fact).collect(); @@ -741,16 +751,20 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, + )?, + CheckKind::All => self.world.query_match_all( + query, + &rule_trusted_origins, + &self.symbols, + &limits.extern_funcs, )?, - CheckKind::All => { - self.world - .query_match_all(query, &rule_trusted_origins, &self.symbols)? - } CheckKind::Reject => !self.world.query_match( query, usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; @@ -799,17 +813,20 @@ impl Authorizer { 0, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), 0, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; @@ -849,6 +866,7 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?; let now = Instant::now(); @@ -898,17 +916,20 @@ impl Authorizer { i + 1, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), i + 1, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; diff --git a/biscuit-auth/src/token/authorizer/snapshot.rs b/biscuit-auth/src/token/authorizer/snapshot.rs index 373aff9f..247dd890 100644 --- a/biscuit-auth/src/token/authorizer/snapshot.rs +++ b/biscuit-auth/src/token/authorizer/snapshot.rs @@ -31,6 +31,7 @@ impl super::Authorizer { max_facts: limits.max_facts, max_iterations: limits.max_iterations, max_time: Duration::from_nanos(limits.max_time), + extern_funcs: Default::default(), }; let execution_time = Duration::from_nanos(execution_time); diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index 97922451..f0220f00 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -1145,6 +1145,7 @@ impl From for Unary { biscuit_parser::builder::Unary::Parens => Unary::Parens, biscuit_parser::builder::Unary::Length => Unary::Length, biscuit_parser::builder::Unary::TypeOf => Unary::TypeOf, + biscuit_parser::builder::Unary::Ffi(name) => Unary::Ffi(name), } } } @@ -1180,6 +1181,7 @@ impl From for Binary { biscuit_parser::builder::Binary::All => Binary::All, biscuit_parser::builder::Binary::Any => Binary::Any, biscuit_parser::builder::Binary::Get => Binary::Get, + biscuit_parser::builder::Binary::Ffi(name) => Binary::Ffi(name), } } } diff --git a/biscuit-auth/tests/macros.rs b/biscuit-auth/tests/macros.rs index 86f0016a..96d6b5b0 100644 --- a/biscuit-auth/tests/macros.rs +++ b/biscuit-auth/tests/macros.rs @@ -34,6 +34,14 @@ check if "my_value".starts_with("my"); check if {false, true}.any($p -> true); "#, ); + + let b = block!(r#"check if "test".extern::toto() && "test".extern::test("test");"#); + + assert_eq!( + b.to_string(), + r#"check if "test".extern::toto() && "test".extern::test("test"); +"# + ); } #[test] diff --git a/biscuit-capi/src/lib.rs b/biscuit-capi/src/lib.rs index 23a0b290..694791e4 100644 --- a/biscuit-capi/src/lib.rs +++ b/biscuit-capi/src/lib.rs @@ -80,6 +80,7 @@ pub enum ErrorKind { FormatPublicKeyTableOverlap, FormatUnknownExternalKey, FormatUnknownSymbol, + FormatMissingFfiName, AppendOnSealed, LogicInvalidBlockRule, LogicUnauthorized, @@ -159,6 +160,7 @@ pub extern "C" fn error_kind() -> ErrorKind { ErrorKind::FormatUnknownExternalKey } Token::Format(Format::UnknownSymbol(_)) => ErrorKind::FormatUnknownSymbol, + Token::Format(Format::MissingFfiName) => ErrorKind::FormatMissingFfiName, Token::AppendOnSealed => ErrorKind::AppendOnSealed, Token::AlreadySealed => ErrorKind::AlreadySealed, Token::Language(_) => ErrorKind::LanguageError, diff --git a/biscuit-capi/tests/capi.rs b/biscuit-capi/tests/capi.rs index af7f615a..d617ce96 100644 --- a/biscuit-capi/tests/capi.rs +++ b/biscuit-capi/tests/capi.rs @@ -114,7 +114,7 @@ biscuit append error? (null) authorizer creation error? (null) authorizer add check error? (null) authorizer add policy error? (null) -authorizer error(code = 21): authorization failed +authorizer error(code = 22): authorization failed failed checks (2): Authorizer check 0: check if right("efgh") Block 1, check 0: check if operation("read") diff --git a/biscuit-parser/src/builder.rs b/biscuit-parser/src/builder.rs index 5bf0634e..486fd85b 100644 --- a/biscuit-parser/src/builder.rs +++ b/biscuit-parser/src/builder.rs @@ -284,6 +284,7 @@ pub enum Unary { Parens, Length, TypeOf, + Ffi(String), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -316,6 +317,7 @@ pub enum Binary { All, Any, Get, + Ffi(String), } #[cfg(feature = "datalog-macro")] @@ -343,6 +345,7 @@ impl ToTokens for Unary { Unary::Parens => quote! {::biscuit_auth::datalog::Unary::Parens }, Unary::Length => quote! {::biscuit_auth::datalog::Unary::Length }, Unary::TypeOf => quote! {::biscuit_auth::datalog::Unary::TypeOf }, + Unary::Ffi(name) => quote! {::biscuit_auth::datalog::Unary::Ffi(#name.to_string()) }, }); } } @@ -383,6 +386,7 @@ impl ToTokens for Binary { Binary::All => quote! { ::biscuit_auth::datalog::Binary::All }, Binary::Any => quote! { ::biscuit_auth::datalog::Binary::Any }, Binary::Get => quote! { ::biscuit_auth::datalog::Binary::Get }, + Binary::Ffi(name) => quote! {::biscuit_auth::datalog::Binary::Ffi(#name.to_string()) }, }); } } diff --git a/biscuit-parser/src/parser.rs b/biscuit-parser/src/parser.rs index 0b44f108..a810b085 100644 --- a/biscuit-parser/src/parser.rs +++ b/biscuit-parser/src/parser.rs @@ -497,6 +497,16 @@ fn binary_op_7(i: &str) -> IResult<&str, builder::Binary, Error> { alt((value(Binary::Mul, tag("*")), value(Binary::Div, tag("/"))))(i) } +fn extern_un(i: &str) -> IResult<&str, builder::Unary, Error> { + let (i, func) = preceded(tag("extern::"), name)(i)?; + Ok((i, builder::Unary::Ffi(func.to_string()))) +} + +fn extern_bin(i: &str) -> IResult<&str, builder::Binary, Error> { + let (i, func) = preceded(tag("extern::"), name)(i)?; + Ok((i, builder::Binary::Ffi(func.to_string()))) +} + fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { use builder::Binary; @@ -510,6 +520,7 @@ fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { value(Binary::All, tag("all")), value(Binary::Any, tag("any")), value(Binary::Get, tag("get")), + extern_bin, ))(i) } @@ -720,6 +731,7 @@ fn unary_method(i: &str) -> IResult<&str, builder::Unary, Error> { let (i, op) = alt(( value(Unary::Length, tag("length")), value(Unary::TypeOf, tag("type")), + extern_un, ))(i)?; let (i, _) = char('(')(i)?; @@ -2609,6 +2621,31 @@ mod tests { Op::Value(array(h.clone())), Op::Value(var("0")), Op::Binary(Binary::Contains), + ] + )) + ) + } + + #[test] + fn extern_funcs() { + use builder::{int, Binary, Op}; + + assert_eq!( + super::expr("2.extern::toto()").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![Op::Value(int(2)), Op::Unary(Unary::Ffi("toto".to_string()))], + )) + ); + + assert_eq!( + super::expr("2.extern::toto(3)").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![ + Op::Value(int(2)), + Op::Value(int(3)), + Op::Binary(Binary::Ffi("toto".to_string())), ], )) ); From 95726afc80dbd0690d3ad5810443f7757d4deae6 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Tue, 22 Oct 2024 11:55:20 +0200 Subject: [PATCH 2/8] bump datalog version when external calls are used --- biscuit-auth/src/datalog/mod.rs | 4 ++-- biscuit-auth/src/token/mod.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index b91ccdfe..4ddf5bcb 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -990,7 +990,7 @@ fn contains_v3_3_op(expressions: &[Expression]) -> bool { expression.ops.iter().any(|op| match op { Op::Value(term) => contains_v3_3_term(term), Op::Closure(_, _) => true, - Op::Unary(Unary::TypeOf) => true, + Op::Unary(unary) => matches!(unary, Unary::TypeOf | Unary::Ffi(_)), Op::Binary(binary) => matches!( binary, Binary::HeterogeneousEqual @@ -999,8 +999,8 @@ fn contains_v3_3_op(expressions: &[Expression]) -> bool { | Binary::LazyOr | Binary::All | Binary::Any + | Binary::Ffi(_) ), - _ => false, }) }) } diff --git a/biscuit-auth/src/token/mod.rs b/biscuit-auth/src/token/mod.rs index 6da50b29..5af64247 100644 --- a/biscuit-auth/src/token/mod.rs +++ b/biscuit-auth/src/token/mod.rs @@ -37,7 +37,7 @@ pub const MAX_SCHEMA_VERSION: u32 = 6; pub const DATALOG_3_1: u32 = 4; /// starting version for 3rd party blocks (datalog 3.2) pub const DATALOG_3_2: u32 = 5; -/// starting version for datalog 3.3 features (reject if, closures, array/map, null, …) +/// starting version for datalog 3.3 features (reject if, closures, array/map, null, external functions, …) pub const DATALOG_3_3: u32 = 6; /// some symbols are predefined and available in every implementation, to avoid From a79175d10d134123f7d65108876178586c0a0693 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Tue, 22 Oct 2024 11:36:52 +0200 Subject: [PATCH 3/8] wip: allow closures in external funcs, use builder terms - using boxed functions instead of function pointers allow capturing the environment - using builder terms instead of datalog terms remove the need for manual symbol management --- biscuit-auth/src/datalog/expression.rs | 98 ++++++++++++++++---------- biscuit-auth/src/error.rs | 4 +- biscuit-auth/src/token/builder.rs | 52 +++++++++++++- 3 files changed, 114 insertions(+), 40 deletions(-) diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index 48ba7772..56d8f2a1 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -1,19 +1,54 @@ -use crate::error; +use crate::{builder, error}; use super::{MapKey, Term}; use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; -use std::collections::{HashMap, HashSet}; -use std::convert::TryFrom; - -type ExternBinary = fn(&mut TemporarySymbolTable, &Term, &Term) -> Result; +use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + convert::TryFrom, +}; + +#[derive(Clone)] +pub struct ExternFunc( + pub Arc< + dyn Fn(builder::Term, Option) -> Result + Send + Sync, + >, +); + +impl std::fmt::Debug for ExternFunc { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} -type ExternUnary = fn(&mut TemporarySymbolTable, &Term) -> Result; +impl ExternFunc { + pub fn new( + f: Arc< + dyn Fn(builder::Term, Option) -> Result + + Send + + Sync, + >, + ) -> Self { + Self(f) + } -#[derive(Debug, Clone)] -pub enum ExternFunc { - Unary(ExternUnary), - Binary(ExternBinary), + pub fn call( + &self, + symbols: &mut TemporarySymbolTable, + name: &str, + left: Term, + right: Option, + ) -> Result { + let left = builder::Term::from_datalog(left, symbols)?; + let right = right + .map(|right| builder::Term::from_datalog(right, symbols)) + .transpose()?; + match self.0(left, right) { + Ok(t) => Ok(t.to_datalog(symbols)), + Err(e) => Err(error::Expression::ExternEvalError(name.to_string(), e)), + } + } } #[derive(Debug, Clone, PartialEq, Hash, Eq)] @@ -77,10 +112,7 @@ impl Unary { let fun = extern_funcs .get(name) .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; - match fun { - ExternFunc::Unary(fun) => fun(symbols, &i), - ExternFunc::Binary(_) => Err(error::Expression::IncorrectArityExtern), - } + fun.call(symbols, name, i, None) } _ => { //println!("unexpected value type on the stack"); @@ -472,10 +504,7 @@ impl Binary { let fun = extern_funcs .get(name) .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; - match fun { - ExternFunc::Binary(fun) => fun(symbols, &left, &right), - ExternFunc::Unary(_) => Err(error::Expression::IncorrectArityExtern), - } + fun.call(symbols, name, left, Some(right)) } _ => { @@ -1666,34 +1695,29 @@ mod tests { let mut extern_funcs: HashMap = Default::default(); extern_funcs.insert( "test_bin".to_owned(), - ExternFunc::Binary(|sym, left, right| match (left, right) { - (Term::Integer(left), Term::Integer(right)) => { + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (builder::Term::Integer(left), Some(builder::Term::Integer(right))) => { println!("{left} {right}"); - Ok(Term::Bool((left % 60) == (right % 60))) + Ok(builder::Term::Bool((left % 60) == (right % 60))) } - (Term::Str(left), Term::Str(right)) => { - let left = sym - .get_symbol(*left) - .ok_or(error::Expression::UnknownSymbol(*left))?; - let right = sym - .get_symbol(*right) - .ok_or(error::Expression::UnknownSymbol(*right))?; - + (builder::Term::Str(left), Some(builder::Term::Str(right))) => { println!("{left} {right}"); - Ok(Term::Bool(left.to_lowercase() == right.to_lowercase())) + Ok(builder::Term::Bool( + left.to_lowercase() == right.to_lowercase(), + )) } - _ => Err(error::Expression::InvalidType), - }), + _ => Err("Expected two strings or two integers".to_string()), + })), ); extern_funcs.insert( "test_un".to_owned(), - ExternFunc::Unary(|_, value| match value { - Term::Integer(value) => Ok(Term::Bool(*value == 42)), + ExternFunc::new(Arc::new(|left, right| match (&left, &right) { + (builder::Term::Integer(left), None) => Ok(builder::boolean(*left == 42)), _ => { - println!("{value:?}"); - Err(error::Expression::InvalidType) + println!("{left:?}, {right:?}"); + Err("expecting a single integer".to_string()) } - }), + })), ); let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs); assert_eq!(res, Ok(Term::Bool(true))); diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 5209703b..d5570ca9 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -252,10 +252,10 @@ pub enum Expression { InvalidStack, #[error("Shadowed variable")] ShadowedVariable, - #[error("Incorrect arity for extern func")] - IncorrectArityExtern, #[error("Undefined extern func: {0}")] UndefinedExtern(String), + #[error("Error while evaluating extern func {0}: {1}")] + ExternEvalError(String, String), } /// runtime limits errors diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index f0220f00..9982fbcd 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -1,7 +1,7 @@ //! helper functions and structure to create tokens and blocks use super::{default_symbol_table, Biscuit, Block}; use crate::crypto::{KeyPair, PublicKey}; -use crate::datalog::{self, get_schema_version, SymbolTable}; +use crate::datalog::{self, get_schema_version, SymbolTable, TemporarySymbolTable}; use crate::error; use crate::token::builder_ext::BuilderExt; use biscuit_parser::parser::parse_block_source; @@ -519,6 +519,56 @@ pub enum MapKey { Parameter(String), } +impl Term { + pub fn to_datalog(self, symbols: &mut TemporarySymbolTable) -> datalog::Term { + match self { + Term::Variable(s) => datalog::Term::Variable(symbols.insert(&s) as u32), + Term::Integer(i) => datalog::Term::Integer(i), + Term::Str(s) => datalog::Term::Str(symbols.insert(&s)), + Term::Date(d) => datalog::Term::Date(d), + Term::Bytes(s) => datalog::Term::Bytes(s), + Term::Bool(b) => datalog::Term::Bool(b), + Term::Set(s) => { + datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect()) + } + Term::Null => datalog::Term::Null, + // The error is caught in the `add_xxx` functions, so this should + // not happen™ + Term::Parameter(s) => panic!("Remaining parameter {}", &s), + } + } + + pub fn from_datalog( + term: datalog::Term, + symbols: &TemporarySymbolTable, + ) -> Result { + Ok(match term { + datalog::Term::Variable(s) => Term::Variable( + symbols + .get_symbol(s as u64) + .ok_or(error::Expression::UnknownVariable(s))? + .to_string(), + ), + datalog::Term::Integer(i) => Term::Integer(i), + datalog::Term::Str(s) => Term::Str( + symbols + .get_symbol(s) + .ok_or(error::Expression::UnknownSymbol(s))? + .to_string(), + ), + datalog::Term::Date(d) => Term::Date(d), + datalog::Term::Bytes(s) => Term::Bytes(s), + datalog::Term::Bool(b) => Term::Bool(b), + datalog::Term::Set(s) => Term::Set( + s.into_iter() + .map(|i| Self::from_datalog(i, symbols)) + .collect::>()?, + ), + datalog::Term::Null => Term::Null, + }) + } +} + impl Convert for Term { fn convert(&self, symbols: &mut SymbolTable) -> datalog::Term { match self { From 2ef9900ee9d1219f11d00cab5f0c84d6ee022d87 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Thu, 24 Oct 2024 15:40:12 +0200 Subject: [PATCH 4/8] add tests for various cases of function definitions --- biscuit-auth/src/datalog/expression.rs | 91 +++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index 56d8f2a1..e26f3836 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -684,7 +684,7 @@ impl Expression { #[cfg(test)] mod tests { - use std::collections::BTreeSet; + use std::collections::{BTreeMap, BTreeSet}; use super::*; use crate::datalog::{MapKey, SymbolTable, TemporarySymbolTable}; @@ -1688,6 +1688,66 @@ mod tests { Op::Value(Term::Integer(42)), Op::Unary(Unary::Ffi("test_un".to_owned())), Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi("test_closure".to_owned())), + Op::Binary(Binary::And), + Op::Value(Term::Str(i)), + Op::Unary(Unary::Ffi("test_closure".to_owned())), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi("test_fn".to_owned())), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Integer(42)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Str(i)), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Str(i)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Bool(true)), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Bool(true)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Date(0)), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Date(0)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Bytes(vec![42])), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Bytes(vec![42])), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Null), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Null), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Array(vec![Term::Null])), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Array(vec![Term::Null])), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Set(BTreeSet::from([Term::Null]))), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Set(BTreeSet::from([Term::Null]))), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Map(BTreeMap::from([ + (MapKey::Integer(42), Term::Null), + (MapKey::Str(i), Term::Null), + ]))), + Op::Unary(Unary::Ffi("id".to_owned())), + Op::Value(Term::Map(BTreeMap::from([ + (MapKey::Integer(42), Term::Null), + (MapKey::Str(i), Term::Null), + ]))), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), ]; let values = HashMap::new(); @@ -1719,7 +1779,36 @@ mod tests { } })), ); + extern_funcs.insert( + "id".to_string(), + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (a, None) => Ok(a), + _ => Err("expecting a single value".to_string()), + })), + ); + let closed_over_int = 42; + let closed_over_string = "test".to_string(); + extern_funcs.insert( + "test_closure".to_owned(), + ExternFunc::new(Arc::new(move |left, right| match (&left, &right) { + (builder::Term::Integer(left), None) => { + Ok(builder::boolean(*left == closed_over_int)) + } + (builder::Term::Str(left), None) => { + Ok(builder::boolean(left == &closed_over_string)) + } + _ => { + println!("{left:?}, {right:?}"); + Err("expecting a single integer".to_string()) + } + })), + ); + extern_funcs.insert("test_fn".to_owned(), ExternFunc::new(Arc::new(toto))); let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs); assert_eq!(res, Ok(Term::Bool(true))); } + + fn toto(_left: builder::Term, _right: Option) -> Result { + Ok(builder::Term::Bool(true)) + } } From 6c7cc801867bafabc9f2c091ab73e9479f15b86f Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Fri, 25 Oct 2024 15:18:50 +0200 Subject: [PATCH 5/8] ensure that FFI function names are only defined on FFI operations --- biscuit-auth/src/error.rs | 2 - biscuit-auth/src/format/convert.rs | 134 +++++++++++++++++------------ biscuit-auth/src/token/builder.rs | 42 +++++++++ biscuit-capi/src/lib.rs | 2 - biscuit-capi/tests/capi.rs | 2 +- 5 files changed, 121 insertions(+), 61 deletions(-) diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index d5570ca9..5690e3eb 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -150,8 +150,6 @@ pub enum Format { UnknownExternalKey, #[error("the symbol id was not in the table")] UnknownSymbol(u64), - #[error("missing FFI name field")] - MissingFfiName, } /// Signature errors diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index ad8a0bab..53225d59 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -738,64 +738,86 @@ pub mod v2 { use schema::{op, op_binary, op_unary}; Ok(match op.content.as_ref() { Some(op::Content::Value(id)) => Op::Value(proto_id_to_token_term(id)?), - Some(op::Content::Unary(u)) => match op_unary::Kind::from_i32(u.kind) { - Some(op_unary::Kind::Negate) => Op::Unary(Unary::Negate), - Some(op_unary::Kind::Parens) => Op::Unary(Unary::Parens), - Some(op_unary::Kind::Length) => Op::Unary(Unary::Length), - Some(op_unary::Kind::TypeOf) => Op::Unary(Unary::TypeOf), - Some(op_unary::Kind::Ffi) => match u.ffi_name.as_ref() { - // todo clementd error if ffi name is defined with another kind - Some(n) => Op::Unary(Unary::Ffi(n.to_owned())), - None => return Err(error::Format::MissingFfiName), - }, - None => { - return Err(error::Format::DeserializationError( - "deserialization error: unary operation is empty".to_string(), - )) - } - }, - Some(op::Content::Binary(b)) => match op_binary::Kind::from_i32(b.kind) { - Some(op_binary::Kind::LessThan) => Op::Binary(Binary::LessThan), - Some(op_binary::Kind::GreaterThan) => Op::Binary(Binary::GreaterThan), - Some(op_binary::Kind::LessOrEqual) => Op::Binary(Binary::LessOrEqual), - Some(op_binary::Kind::GreaterOrEqual) => Op::Binary(Binary::GreaterOrEqual), - Some(op_binary::Kind::Equal) => Op::Binary(Binary::Equal), - Some(op_binary::Kind::Contains) => Op::Binary(Binary::Contains), - Some(op_binary::Kind::Prefix) => Op::Binary(Binary::Prefix), - Some(op_binary::Kind::Suffix) => Op::Binary(Binary::Suffix), - Some(op_binary::Kind::Regex) => Op::Binary(Binary::Regex), - Some(op_binary::Kind::Add) => Op::Binary(Binary::Add), - Some(op_binary::Kind::Sub) => Op::Binary(Binary::Sub), - Some(op_binary::Kind::Mul) => Op::Binary(Binary::Mul), - Some(op_binary::Kind::Div) => Op::Binary(Binary::Div), - Some(op_binary::Kind::And) => Op::Binary(Binary::And), - Some(op_binary::Kind::Or) => Op::Binary(Binary::Or), - Some(op_binary::Kind::Intersection) => Op::Binary(Binary::Intersection), - Some(op_binary::Kind::Union) => Op::Binary(Binary::Union), - Some(op_binary::Kind::BitwiseAnd) => Op::Binary(Binary::BitwiseAnd), - Some(op_binary::Kind::BitwiseOr) => Op::Binary(Binary::BitwiseOr), - Some(op_binary::Kind::BitwiseXor) => Op::Binary(Binary::BitwiseXor), - Some(op_binary::Kind::NotEqual) => Op::Binary(Binary::NotEqual), - Some(op_binary::Kind::HeterogeneousEqual) => Op::Binary(Binary::HeterogeneousEqual), - Some(op_binary::Kind::HeterogeneousNotEqual) => { - Op::Binary(Binary::HeterogeneousNotEqual) + Some(op::Content::Unary(u)) => { + match (op_unary::Kind::from_i32(u.kind), u.ffi_name.as_ref()) { + (Some(op_unary::Kind::Negate), None) => Op::Unary(Unary::Negate), + (Some(op_unary::Kind::Parens), None) => Op::Unary(Unary::Parens), + (Some(op_unary::Kind::Length), None) => Op::Unary(Unary::Length), + (Some(op_unary::Kind::TypeOf), None) => Op::Unary(Unary::TypeOf), + (Some(op_unary::Kind::Ffi), Some(n)) => Op::Unary(Unary::Ffi(n.to_owned())), + (Some(op_unary::Kind::Ffi), None) => { + return Err(error::Format::DeserializationError( + "deserialization error: missing ffi name".to_string(), + )) + } + (Some(_), Some(_)) => { + return Err(error::Format::DeserializationError( + "deserialization error: ffi name set on a regular unary operation" + .to_string(), + )) + } + (None, _) => { + return Err(error::Format::DeserializationError( + "deserialization error: unary operation is empty".to_string(), + )) + } } - Some(op_binary::Kind::LazyAnd) => Op::Binary(Binary::LazyAnd), - Some(op_binary::Kind::LazyOr) => Op::Binary(Binary::LazyOr), - Some(op_binary::Kind::All) => Op::Binary(Binary::All), - Some(op_binary::Kind::Any) => Op::Binary(Binary::Any), - Some(op_binary::Kind::Get) => Op::Binary(Binary::Get), - Some(op_binary::Kind::Ffi) => match b.ffi_name.as_ref() { - // todo clementd error if ffi name is defined with another kind - Some(n) => Op::Binary(Binary::Ffi(n.to_owned())), - None => return Err(error::Format::MissingFfiName), - }, - None => { - return Err(error::Format::DeserializationError( - "deserialization error: binary operation is empty".to_string(), - )) + } + Some(op::Content::Binary(b)) => { + match (op_binary::Kind::from_i32(b.kind), b.ffi_name.as_ref()) { + (Some(op_binary::Kind::LessThan), None) => Op::Binary(Binary::LessThan), + (Some(op_binary::Kind::GreaterThan), None) => Op::Binary(Binary::GreaterThan), + (Some(op_binary::Kind::LessOrEqual), None) => Op::Binary(Binary::LessOrEqual), + (Some(op_binary::Kind::GreaterOrEqual), None) => { + Op::Binary(Binary::GreaterOrEqual) + } + (Some(op_binary::Kind::Equal), None) => Op::Binary(Binary::Equal), + (Some(op_binary::Kind::Contains), None) => Op::Binary(Binary::Contains), + (Some(op_binary::Kind::Prefix), None) => Op::Binary(Binary::Prefix), + (Some(op_binary::Kind::Suffix), None) => Op::Binary(Binary::Suffix), + (Some(op_binary::Kind::Regex), None) => Op::Binary(Binary::Regex), + (Some(op_binary::Kind::Add), None) => Op::Binary(Binary::Add), + (Some(op_binary::Kind::Sub), None) => Op::Binary(Binary::Sub), + (Some(op_binary::Kind::Mul), None) => Op::Binary(Binary::Mul), + (Some(op_binary::Kind::Div), None) => Op::Binary(Binary::Div), + (Some(op_binary::Kind::And), None) => Op::Binary(Binary::And), + (Some(op_binary::Kind::Or), None) => Op::Binary(Binary::Or), + (Some(op_binary::Kind::Intersection), None) => Op::Binary(Binary::Intersection), + (Some(op_binary::Kind::Union), None) => Op::Binary(Binary::Union), + (Some(op_binary::Kind::BitwiseAnd), None) => Op::Binary(Binary::BitwiseAnd), + (Some(op_binary::Kind::BitwiseOr), None) => Op::Binary(Binary::BitwiseOr), + (Some(op_binary::Kind::BitwiseXor), None) => Op::Binary(Binary::BitwiseXor), + (Some(op_binary::Kind::NotEqual), None) => Op::Binary(Binary::NotEqual), + (Some(op_binary::Kind::HeterogeneousEqual), None) => { + Op::Binary(Binary::HeterogeneousEqual) + } + (Some(op_binary::Kind::HeterogeneousNotEqual), None) => { + Op::Binary(Binary::HeterogeneousNotEqual) + } + (Some(op_binary::Kind::LazyAnd), None) => Op::Binary(Binary::LazyAnd), + (Some(op_binary::Kind::LazyOr), None) => Op::Binary(Binary::LazyOr), + (Some(op_binary::Kind::All), None) => Op::Binary(Binary::All), + (Some(op_binary::Kind::Any), None) => Op::Binary(Binary::Any), + (Some(op_binary::Kind::Get), None) => Op::Binary(Binary::Get), + (Some(op_binary::Kind::Ffi), Some(n)) => Op::Binary(Binary::Ffi(n.to_owned())), + (Some(op_binary::Kind::Ffi), None) => { + return Err(error::Format::DeserializationError( + "deserialization error: missing ffi name".to_string(), + )) + } + (Some(_), Some(_)) => { + return Err(error::Format::DeserializationError( + "deserialization error: ffi name set on a regular binary operation" + .to_string(), + )) + } + (None, _) => { + return Err(error::Format::DeserializationError( + "deserialization error: binary operation is empty".to_string(), + )) + } } - }, + } Some(op::Content::Closure(op_closure)) => Op::Closure( op_closure.params.clone(), op_closure diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index 9982fbcd..abd3d867 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -532,6 +532,25 @@ impl Term { datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect()) } Term::Null => datalog::Term::Null, + Term::Array(a) => { + datalog::Term::Array(a.into_iter().map(|i| i.to_datalog(symbols)).collect()) + } + Term::Map(m) => datalog::Term::Map( + m.into_iter() + .map(|(k, i)| { + ( + match k { + MapKey::Integer(i) => datalog::MapKey::Integer(i), + MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(&s)), + // The error is caught in the `add_xxx` functions, so this should + // not happen™ + MapKey::Parameter(s) => panic!("Remaining parameter {}", &s), + }, + i.to_datalog(symbols), + ) + }) + .collect(), + ), // The error is caught in the `add_xxx` functions, so this should // not happen™ Term::Parameter(s) => panic!("Remaining parameter {}", &s), @@ -565,6 +584,29 @@ impl Term { .collect::>()?, ), datalog::Term::Null => Term::Null, + datalog::Term::Array(a) => Term::Array( + a.into_iter() + .map(|i| Self::from_datalog(i, symbols)) + .collect::>()?, + ), + datalog::Term::Map(m) => Term::Map( + m.into_iter() + .map(|(k, i)| { + Ok(( + match k { + datalog::MapKey::Integer(i) => MapKey::Integer(i), + datalog::MapKey::Str(s) => MapKey::Str( + symbols + .get_symbol(s) + .ok_or(error::Expression::UnknownSymbol(s))? + .to_string(), + ), + }, + Self::from_datalog(i, symbols)?, + )) + }) + .collect::>()?, + ), }) } } diff --git a/biscuit-capi/src/lib.rs b/biscuit-capi/src/lib.rs index 694791e4..23a0b290 100644 --- a/biscuit-capi/src/lib.rs +++ b/biscuit-capi/src/lib.rs @@ -80,7 +80,6 @@ pub enum ErrorKind { FormatPublicKeyTableOverlap, FormatUnknownExternalKey, FormatUnknownSymbol, - FormatMissingFfiName, AppendOnSealed, LogicInvalidBlockRule, LogicUnauthorized, @@ -160,7 +159,6 @@ pub extern "C" fn error_kind() -> ErrorKind { ErrorKind::FormatUnknownExternalKey } Token::Format(Format::UnknownSymbol(_)) => ErrorKind::FormatUnknownSymbol, - Token::Format(Format::MissingFfiName) => ErrorKind::FormatMissingFfiName, Token::AppendOnSealed => ErrorKind::AppendOnSealed, Token::AlreadySealed => ErrorKind::AlreadySealed, Token::Language(_) => ErrorKind::LanguageError, diff --git a/biscuit-capi/tests/capi.rs b/biscuit-capi/tests/capi.rs index d617ce96..af7f615a 100644 --- a/biscuit-capi/tests/capi.rs +++ b/biscuit-capi/tests/capi.rs @@ -114,7 +114,7 @@ biscuit append error? (null) authorizer creation error? (null) authorizer add check error? (null) authorizer add policy error? (null) -authorizer error(code = 22): authorization failed +authorizer error(code = 21): authorization failed failed checks (2): Authorizer check 0: check if right("efgh") Block 1, check 0: check if operation("read") From 0ed24827b514e7ad93041dcf4054a696ef56371d Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Tue, 12 Nov 2024 10:08:47 +0100 Subject: [PATCH 6/8] Add sample for FFI calls --- biscuit-auth/examples/testcases.rs | 69 +++++++++++++++++++++++++++- biscuit-auth/samples/README.md | 48 +++++++++++++++++++ biscuit-auth/samples/samples.json | 41 +++++++++++++++++ biscuit-auth/samples/test035_ffi.bc | Bin 0 -> 234 bytes 4 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 biscuit-auth/samples/test035_ffi.bc diff --git a/biscuit-auth/examples/testcases.rs b/biscuit-auth/examples/testcases.rs index 05f44e44..648e2656 100644 --- a/biscuit-auth/examples/testcases.rs +++ b/biscuit-auth/examples/testcases.rs @@ -10,9 +10,14 @@ use biscuit::macros::*; use biscuit::Authorizer; use biscuit::{builder::*, builder_ext::*, Biscuit}; use biscuit::{KeyPair, PrivateKey, PublicKey}; +use biscuit_auth::builder; +use biscuit_auth::datalog::ExternFunc; +use biscuit_auth::datalog::RunLimits; use prost::Message; use rand::prelude::*; use serde::Serialize; +use std::collections::HashMap; +use std::sync::Arc; use std::{ collections::{BTreeMap, BTreeSet}, fs::File, @@ -157,6 +162,9 @@ fn run(target: String, root_key: Option, test: bool, json: bool) { add_test_result(&mut results, type_of(&target, &root, test)); add_test_result(&mut results, array_map(&target, &root, test)); + + add_test_result(&mut results, ffi(&target, &root, test)); + if json { let s = serde_json::to_string_pretty(&TestCases { root_private_key: hex::encode(root.private().to_bytes()), @@ -297,6 +305,15 @@ enum AuthorizerResult { } fn validate_token(root: &KeyPair, data: &[u8], authorizer_code: &str) -> Validation { + validate_token_with_limits(root, data, authorizer_code, RunLimits::default()) +} + +fn validate_token_with_limits( + root: &KeyPair, + data: &[u8], + authorizer_code: &str, + run_limits: RunLimits, +) -> Validation { let token = match Biscuit::from(&data[..], &root.public()) { Ok(t) => t, Err(e) => { @@ -331,7 +348,7 @@ fn validate_token(root: &KeyPair, data: &[u8], authorizer_code: &str) -> Validat } }; - let res = authorizer.authorize(); + let res = authorizer.authorize_with_limits(run_limits); //println!("authorizer world:\n{}", authorizer.print_world()); let (_, _, _, policies) = authorizer.dump(); let snapshot = authorizer.snapshot().unwrap(); @@ -2269,6 +2286,56 @@ fn array_map(target: &str, root: &KeyPair, test: bool) -> TestResult { } } +fn ffi(target: &str, root: &KeyPair, test: bool) -> TestResult { + let mut rng: StdRng = SeedableRng::seed_from_u64(1234); + let title = "test ffi calls (v6 blocks)".to_string(); + let filename = "test035_ffi".to_string(); + let token; + + let biscuit = + biscuit!(r#"check if true.extern::test(), "a".extern::test("a") == "equal strings""#) + .build_with_rng(&root, SymbolTable::default(), &mut rng) + .unwrap(); + token = print_blocks(&biscuit); + + let data = write_or_load_testcase(target, &filename, root, &biscuit, test); + + let mut validations = BTreeMap::new(); + validations.insert( + "".to_string(), + validate_token_with_limits( + root, + &data[..], + "allow if true", + RunLimits { + extern_funcs: HashMap::from([( + "test".to_string(), + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (t, None) => Ok(t), + (builder::Term::Str(left), Some(builder::Term::Str(right))) + if left == right => + { + Ok(builder::Term::Str("equal strings".to_string())) + } + (builder::Term::Str(_), Some(builder::Term::Str(_))) => { + Ok(builder::Term::Str("different strings".to_string())) + } + _ => Err("unsupported operands".to_string()), + })), + )]), + ..Default::default() + }, + ), + ); + + TestResult { + title, + filename, + token, + validations, + } +} + fn print_blocks(token: &Biscuit) -> Vec { let mut v = Vec::new(); diff --git a/biscuit-auth/samples/README.md b/biscuit-auth/samples/README.md index 11681f0c..b86cf835 100644 --- a/biscuit-auth/samples/README.md +++ b/biscuit-auth/samples/README.md @@ -3139,3 +3139,51 @@ World { result: `Ok(0)` + +------------------------------ + +## test ffi calls (v6 blocks): test035_ffi.bc +### token + +authority: +symbols: ["a", "equal strings"] + +public keys: [] + +``` +check if true.extern::test(), "a".extern::test("a") == "equal strings"; +``` + +### validation + +authorizer code: +``` +allow if true; +``` + +revocation ids: +- `b1696fd9f9ec456d65a863df034cb132dc7dca076d16f5bc3e73986a4cc88cc4e7902dc8519cb60961e3f33799c147f874c7e0d7e12ef1b461e361e0c0aa580b` + +authorizer world: +``` +World { + facts: [] + rules: [] + checks: [ + Checks { + origin: Some( + 0, + ), + checks: [ + "check if true.extern::test(), \"a\".extern::test(\"a\") == \"equal strings\"", + ], + }, +] + policies: [ + "allow if true", +] +} +``` + +result: `Ok(0)` + diff --git a/biscuit-auth/samples/samples.json b/biscuit-auth/samples/samples.json index 80519a1e..d71b90f2 100644 --- a/biscuit-auth/samples/samples.json +++ b/biscuit-auth/samples/samples.json @@ -2913,6 +2913,47 @@ ] } } + }, + { + "title": "test ffi calls (v6 blocks)", + "filename": "test035_ffi.bc", + "token": [ + { + "symbols": [ + "a", + "equal strings" + ], + "public_keys": [], + "external_key": null, + "code": "check if true.extern::test(), \"a\".extern::test(\"a\") == \"equal strings\";\n" + } + ], + "validations": { + "": { + "world": { + "facts": [], + "rules": [], + "checks": [ + { + "origin": 0, + "checks": [ + "check if true.extern::test(), \"a\".extern::test(\"a\") == \"equal strings\"" + ] + } + ], + "policies": [ + "allow if true" + ] + }, + "result": { + "Ok": 0 + }, + "authorizer_code": "allow if true;\n", + "revocation_ids": [ + "b1696fd9f9ec456d65a863df034cb132dc7dca076d16f5bc3e73986a4cc88cc4e7902dc8519cb60961e3f33799c147f874c7e0d7e12ef1b461e361e0c0aa580b" + ] + } + } } ] } diff --git a/biscuit-auth/samples/test035_ffi.bc b/biscuit-auth/samples/test035_ffi.bc new file mode 100644 index 0000000000000000000000000000000000000000..6387dfaa062b682388a5dd88f01f72e555cbf303 GIT binary patch literal 234 zcmWeS%*YkV#hA#&n_5_!n4?f!Qk0pOUM#_8{6%Ga=1%c4x0UH+vM4GP`H9w}d z%P4tfjPWw}h*j%1sz);mJ!+A1*qE7r^XD7a+|(7x_nCb*8r`Wq#hxqnb&p-~j4YoM zJx88T&^-}2XB%hY$TS a{})xu(^fvTD4<%nXRfro)|QC-x`_bUCr@4g literal 0 HcmV?d00001 From 070bda09fc28367b2141927b891101938841ca02 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Wed, 13 Nov 2024 10:19:15 +0100 Subject: [PATCH 7/8] intern FFI call names Instead of storing strings directly in the ops, do as we do for everything else and use the symbol table. This required duplicating `biscuit_parser::builder::Binary` and `Unary` in the `biscuit_auth::builder` module (which previously used the definitions from the `datalog` module directly). There is a lot of duplication between `biscuit_parser::builder` and `biscuit_auth::builder`, with a circular-ish dependency (biscuit_auth depends on biscuit parser, but code generated by the `ToTokens` impl blocks in biscuit parser depend on `biscuit_auth::builder`). --- biscuit-auth/samples/README.md | 4 +- biscuit-auth/samples/samples.json | 3 +- biscuit-auth/samples/test035_ffi.bc | Bin 234 -> 234 bytes biscuit-auth/src/datalog/expression.rs | 70 +++++++----- biscuit-auth/src/format/convert.rs | 4 +- biscuit-auth/src/format/schema.proto | 4 +- biscuit-auth/src/format/schema.rs | 8 +- biscuit-auth/src/token/builder.rs | 149 ++++++++++++++++++++++++- biscuit-parser/src/builder.rs | 68 +++++------ 9 files changed, 234 insertions(+), 76 deletions(-) diff --git a/biscuit-auth/samples/README.md b/biscuit-auth/samples/README.md index b86cf835..ec6e18fb 100644 --- a/biscuit-auth/samples/README.md +++ b/biscuit-auth/samples/README.md @@ -3146,7 +3146,7 @@ result: `Ok(0)` ### token authority: -symbols: ["a", "equal strings"] +symbols: ["test", "a", "equal strings"] public keys: [] @@ -3162,7 +3162,7 @@ allow if true; ``` revocation ids: -- `b1696fd9f9ec456d65a863df034cb132dc7dca076d16f5bc3e73986a4cc88cc4e7902dc8519cb60961e3f33799c147f874c7e0d7e12ef1b461e361e0c0aa580b` +- `faf26fe6f5dfa08c114a0a29321405b6fb7be79b0d80694d27925f7deb01effe5707600e42fd74f9a1d2920466446d51949155f4548f0fd68f3e9326c7e12404` authorizer world: ``` diff --git a/biscuit-auth/samples/samples.json b/biscuit-auth/samples/samples.json index d71b90f2..1401846a 100644 --- a/biscuit-auth/samples/samples.json +++ b/biscuit-auth/samples/samples.json @@ -2920,6 +2920,7 @@ "token": [ { "symbols": [ + "test", "a", "equal strings" ], @@ -2950,7 +2951,7 @@ }, "authorizer_code": "allow if true;\n", "revocation_ids": [ - "b1696fd9f9ec456d65a863df034cb132dc7dca076d16f5bc3e73986a4cc88cc4e7902dc8519cb60961e3f33799c147f874c7e0d7e12ef1b461e361e0c0aa580b" + "faf26fe6f5dfa08c114a0a29321405b6fb7be79b0d80694d27925f7deb01effe5707600e42fd74f9a1d2920466446d51949155f4548f0fd68f3e9326c7e12404" ] } } diff --git a/biscuit-auth/samples/test035_ffi.bc b/biscuit-auth/samples/test035_ffi.bc index 6387dfaa062b682388a5dd88f01f72e555cbf303..d5bb3a8dce7887643122c4fe7c01dc441a4de614 100644 GIT binary patch delta 163 zcmaFG_=+)H=rAK!Bo|9bYH_luEtfSH6Nj`EKNkxZ zlK~?ayAUe}i$DX1lnNIs7qdho2aIBuV&#wlN`U!I6T|f#etpV+_VxaP9zicIO(PN3 zZNICZ&*p8&^i`h}U;CQz{l9Sb1U{#~B|jHln#7Xkk{dW>V(6EUe*SCyc9YePKU86v HcuWNVJ9#oy delta 163 zcmaFG_=+)H=rAK!Bo|{M7jJ4|X=08-aY<2TUV5RiIh4QD;Kjw0|$)alH%ZyfyjY{8YhP9J8aC%zxnfxYi{a_{h{8+Es2j49~@W}!9DSq F3IN5rI5q$P diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index e26f3836..ecb47a28 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -1,6 +1,6 @@ use crate::{builder, error}; -use super::{MapKey, Term}; +use super::{MapKey, SymbolIndex, Term}; use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; use std::sync::Arc; @@ -71,7 +71,7 @@ pub enum Unary { Parens, Length, TypeOf, - Ffi(String), + Ffi(SymbolIndex), } impl Unary { @@ -109,10 +109,14 @@ impl Unary { Ok(Term::Str(sym)) } (Unary::Ffi(name), i) => { + let name = symbols + .get_symbol(*name) + .ok_or(error::Expression::UnknownSymbol(*name))? + .to_owned(); let fun = extern_funcs - .get(name) + .get(&name) .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; - fun.call(symbols, name, i, None) + fun.call(symbols, &name, i, None) } _ => { //println!("unexpected value type on the stack"); @@ -121,13 +125,15 @@ impl Unary { } } - pub fn print(&self, value: String, _symbols: &SymbolTable) -> String { + pub fn print(&self, value: String, symbols: &SymbolTable) -> String { match self { Unary::Negate => format!("!{}", value), Unary::Parens => format!("({})", value), Unary::Length => format!("{}.length()", value), Unary::TypeOf => format!("{}.type()", value), - Unary::Ffi(name) => format!("{value}.extern::{name}()"), + Unary::Ffi(name) => { + format!("{value}.extern::{}()", symbols.print_symbol_default(*name)) + } } } } @@ -163,7 +169,7 @@ pub enum Binary { All, Any, Get, - Ffi(String), + Ffi(SymbolIndex), } impl Binary { @@ -501,10 +507,14 @@ impl Binary { // FFI (Binary::Ffi(name), left, right) => { + let name = symbols + .get_symbol(*name) + .ok_or(error::Expression::UnknownSymbol(*name))? + .to_owned(); let fun = extern_funcs - .get(name) + .get(&name) .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; - fun.call(symbols, name, left, Some(right)) + fun.call(symbols, &name, left, Some(right)) } _ => { @@ -514,7 +524,7 @@ impl Binary { } } - pub fn print(&self, left: String, right: String, _symbols: &SymbolTable) -> String { + pub fn print(&self, left: String, right: String, symbols: &SymbolTable) -> String { match self { Binary::LessThan => format!("{} < {}", left, right), Binary::GreaterThan => format!("{} > {}", left, right), @@ -544,7 +554,10 @@ impl Binary { Binary::All => format!("{left}.all({right})"), Binary::Any => format!("{left}.any({right})"), Binary::Get => format!("{left}.get({right})"), - Binary::Ffi(name) => format!("{left}.extern::{name}({right})"), + Binary::Ffi(name) => format!( + "{left}.extern::{}({right})", + symbols.print_symbol_default(*name) + ), } } } @@ -1676,64 +1689,69 @@ mod tests { let mut symbols = SymbolTable::new(); let i = symbols.insert("test"); let j = symbols.insert("TeSt"); + let test_bin = symbols.insert("test_bin"); + let test_un = symbols.insert("test_un"); + let test_closure = symbols.insert("test_closure"); + let test_fn = symbols.insert("test_fn"); + let id_fn = symbols.insert("id"); let mut tmp_symbols = TemporarySymbolTable::new(&symbols); let ops = vec![ Op::Value(Term::Integer(60)), Op::Value(Term::Integer(0)), - Op::Binary(Binary::Ffi("test_bin".to_owned())), + Op::Binary(Binary::Ffi(test_bin)), Op::Value(Term::Str(i)), Op::Value(Term::Str(j)), - Op::Binary(Binary::Ffi("test_bin".to_owned())), + Op::Binary(Binary::Ffi(test_bin)), Op::Binary(Binary::And), Op::Value(Term::Integer(42)), - Op::Unary(Unary::Ffi("test_un".to_owned())), + Op::Unary(Unary::Ffi(test_un)), Op::Binary(Binary::And), Op::Value(Term::Integer(42)), - Op::Unary(Unary::Ffi("test_closure".to_owned())), + Op::Unary(Unary::Ffi(test_closure)), Op::Binary(Binary::And), Op::Value(Term::Str(i)), - Op::Unary(Unary::Ffi("test_closure".to_owned())), + Op::Unary(Unary::Ffi(test_closure)), Op::Binary(Binary::And), Op::Value(Term::Integer(42)), - Op::Unary(Unary::Ffi("test_fn".to_owned())), + Op::Unary(Unary::Ffi(test_fn)), Op::Binary(Binary::And), Op::Value(Term::Integer(42)), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Integer(42)), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Str(i)), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Str(i)), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Bool(true)), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Bool(true)), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Date(0)), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Date(0)), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Bytes(vec![42])), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Bytes(vec![42])), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Null), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Null), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Array(vec![Term::Null])), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Array(vec![Term::Null])), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), Op::Value(Term::Set(BTreeSet::from([Term::Null]))), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Set(BTreeSet::from([Term::Null]))), Op::Binary(Binary::HeterogeneousEqual), Op::Binary(Binary::And), @@ -1741,7 +1759,7 @@ mod tests { (MapKey::Integer(42), Term::Null), (MapKey::Str(i), Term::Null), ]))), - Op::Unary(Unary::Ffi("id".to_owned())), + Op::Unary(Unary::Ffi(id_fn)), Op::Value(Term::Map(BTreeMap::from([ (MapKey::Integer(42), Term::Null), (MapKey::Str(i), Term::Null), diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index 53225d59..df8fa290 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -744,7 +744,7 @@ pub mod v2 { (Some(op_unary::Kind::Parens), None) => Op::Unary(Unary::Parens), (Some(op_unary::Kind::Length), None) => Op::Unary(Unary::Length), (Some(op_unary::Kind::TypeOf), None) => Op::Unary(Unary::TypeOf), - (Some(op_unary::Kind::Ffi), Some(n)) => Op::Unary(Unary::Ffi(n.to_owned())), + (Some(op_unary::Kind::Ffi), Some(n)) => Op::Unary(Unary::Ffi(*n)), (Some(op_unary::Kind::Ffi), None) => { return Err(error::Format::DeserializationError( "deserialization error: missing ffi name".to_string(), @@ -799,7 +799,7 @@ pub mod v2 { (Some(op_binary::Kind::All), None) => Op::Binary(Binary::All), (Some(op_binary::Kind::Any), None) => Op::Binary(Binary::Any), (Some(op_binary::Kind::Get), None) => Op::Binary(Binary::Get), - (Some(op_binary::Kind::Ffi), Some(n)) => Op::Binary(Binary::Ffi(n.to_owned())), + (Some(op_binary::Kind::Ffi), Some(n)) => Op::Binary(Binary::Ffi(*n)), (Some(op_binary::Kind::Ffi), None) => { return Err(error::Format::DeserializationError( "deserialization error: missing ffi name".to_string(), diff --git a/biscuit-auth/src/format/schema.proto b/biscuit-auth/src/format/schema.proto index ad802596..d8c91458 100644 --- a/biscuit-auth/src/format/schema.proto +++ b/biscuit-auth/src/format/schema.proto @@ -152,7 +152,7 @@ message OpUnary { } required Kind kind = 1; - optional string ffiName = 2; + optional uint64 ffiName = 2; } message OpBinary { @@ -189,7 +189,7 @@ message OpBinary { } required Kind kind = 1; - optional string ffiName = 2; + optional uint64 ffiName = 2; } message OpClosure { diff --git a/biscuit-auth/src/format/schema.rs b/biscuit-auth/src/format/schema.rs index c8189ae0..80cb1aa2 100644 --- a/biscuit-auth/src/format/schema.rs +++ b/biscuit-auth/src/format/schema.rs @@ -235,8 +235,8 @@ pub mod op { pub struct OpUnary { #[prost(enumeration="op_unary::Kind", required, tag="1")] pub kind: i32, - #[prost(string, optional, tag="2")] - pub ffi_name: ::core::option::Option<::prost::alloc::string::String>, + #[prost(uint64, optional, tag="2")] + pub ffi_name: ::core::option::Option, } /// Nested message and enum types in `OpUnary`. pub mod op_unary { @@ -254,8 +254,8 @@ pub mod op_unary { pub struct OpBinary { #[prost(enumeration="op_binary::Kind", required, tag="1")] pub kind: i32, - #[prost(string, optional, tag="2")] - pub ffi_name: ::core::option::Option<::prost::alloc::string::String>, + #[prost(uint64, optional, tag="2")] + pub ffi_name: ::core::option::Option, } /// Nested message and enum types in `OpBinary`. pub mod op_binary { diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index abd3d867..248afe6e 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -17,7 +17,10 @@ use std::{ }; // reexport those because the builder uses the same definitions -pub use crate::datalog::{Binary, Expression as DatalogExpression, Op as DatalogOp, Unary}; +pub use crate::datalog::{ + Binary as DatalogBinary, Expression as DatalogExpression, Op as DatalogOp, + Unary as DatalogUnary, +}; /// creates a Block content to append to an existing token #[derive(Clone, Debug, Default)] @@ -419,6 +422,50 @@ pub trait Convert: Sized { } } +/// Builder for a unary operation +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Unary { + Negate, + Parens, + Length, + TypeOf, + Ffi(String), +} + +/// Builder for a binary operation +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Binary { + LessThan, + GreaterThan, + LessOrEqual, + GreaterOrEqual, + Equal, + Contains, + Prefix, + Suffix, + Regex, + Add, + Sub, + Mul, + Div, + And, + Or, + Intersection, + Union, + BitwiseAnd, + BitwiseOr, + BitwiseXor, + NotEqual, + HeterogeneousEqual, + HeterogeneousNotEqual, + LazyAnd, + LazyOr, + All, + Any, + Get, + Ffi(String), +} + /// Builder for a Datalog value #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Term { @@ -1191,8 +1238,8 @@ impl Convert for Op { fn convert(&self, symbols: &mut SymbolTable) -> datalog::Op { match self { Op::Value(t) => datalog::Op::Value(t.convert(symbols)), - Op::Unary(u) => datalog::Op::Unary(u.clone()), - Op::Binary(b) => datalog::Op::Binary(b.clone()), + Op::Unary(u) => datalog::Op::Unary(u.convert(symbols)), + Op::Binary(b) => datalog::Op::Binary(b.convert(symbols)), Op::Closure(ps, os) => datalog::Op::Closure( ps.iter().map(|p| symbols.insert(p) as u32).collect(), os.iter().map(|o| o.convert(symbols)).collect(), @@ -1203,8 +1250,8 @@ impl Convert for Op { fn convert_from(op: &datalog::Op, symbols: &SymbolTable) -> Result { Ok(match op { datalog::Op::Value(t) => Op::Value(Term::convert_from(t, symbols)?), - datalog::Op::Unary(u) => Op::Unary(u.clone()), - datalog::Op::Binary(b) => Op::Binary(b.clone()), + datalog::Op::Unary(u) => Op::Unary(Unary::convert_from(u, symbols)?), + datalog::Op::Binary(b) => Op::Binary(Binary::convert_from(b, symbols)?), datalog::Op::Closure(ps, os) => Op::Closure( ps.iter() .map(|p| symbols.print_symbol(*p as u64)) @@ -1230,6 +1277,28 @@ impl From for Op { } } +impl Convert for Unary { + fn convert(&self, symbols: &mut SymbolTable) -> datalog::Unary { + match self { + Unary::Negate => datalog::Unary::Negate, + Unary::Parens => datalog::Unary::Parens, + Unary::Length => datalog::Unary::Length, + Unary::TypeOf => datalog::Unary::TypeOf, + Unary::Ffi(n) => datalog::Unary::Ffi(symbols.insert(n)), + } + } + + fn convert_from(f: &datalog::Unary, symbols: &SymbolTable) -> Result { + match f { + datalog::Unary::Negate => Ok(Unary::Negate), + datalog::Unary::Parens => Ok(Unary::Parens), + datalog::Unary::Length => Ok(Unary::Length), + datalog::Unary::TypeOf => Ok(Unary::TypeOf), + datalog::Unary::Ffi(i) => Ok(Unary::Ffi(symbols.print_symbol(*i)?)), + } + } +} + impl From for Unary { fn from(unary: biscuit_parser::builder::Unary) -> Self { match unary { @@ -1242,6 +1311,76 @@ impl From for Unary { } } +impl Convert for Binary { + fn convert(&self, symbols: &mut SymbolTable) -> datalog::Binary { + match self { + Binary::LessThan => datalog::Binary::LessThan, + Binary::GreaterThan => datalog::Binary::GreaterThan, + Binary::LessOrEqual => datalog::Binary::LessOrEqual, + Binary::GreaterOrEqual => datalog::Binary::GreaterOrEqual, + Binary::Equal => datalog::Binary::Equal, + Binary::Contains => datalog::Binary::Contains, + Binary::Prefix => datalog::Binary::Prefix, + Binary::Suffix => datalog::Binary::Suffix, + Binary::Regex => datalog::Binary::Regex, + Binary::Add => datalog::Binary::Add, + Binary::Sub => datalog::Binary::Sub, + Binary::Mul => datalog::Binary::Mul, + Binary::Div => datalog::Binary::Div, + Binary::And => datalog::Binary::And, + Binary::Or => datalog::Binary::Or, + Binary::Intersection => datalog::Binary::Intersection, + Binary::Union => datalog::Binary::Union, + Binary::BitwiseAnd => datalog::Binary::BitwiseAnd, + Binary::BitwiseOr => datalog::Binary::BitwiseOr, + Binary::BitwiseXor => datalog::Binary::BitwiseXor, + Binary::NotEqual => datalog::Binary::NotEqual, + Binary::HeterogeneousEqual => datalog::Binary::HeterogeneousEqual, + Binary::HeterogeneousNotEqual => datalog::Binary::HeterogeneousNotEqual, + Binary::LazyAnd => datalog::Binary::LazyAnd, + Binary::LazyOr => datalog::Binary::LazyOr, + Binary::All => datalog::Binary::All, + Binary::Any => datalog::Binary::Any, + Binary::Get => datalog::Binary::Get, + Binary::Ffi(n) => datalog::Binary::Ffi(symbols.insert(n)), + } + } + + fn convert_from(f: &datalog::Binary, symbols: &SymbolTable) -> Result { + match f { + datalog::Binary::LessThan => Ok(Binary::LessThan), + datalog::Binary::GreaterThan => Ok(Binary::GreaterThan), + datalog::Binary::LessOrEqual => Ok(Binary::LessOrEqual), + datalog::Binary::GreaterOrEqual => Ok(Binary::GreaterOrEqual), + datalog::Binary::Equal => Ok(Binary::Equal), + datalog::Binary::Contains => Ok(Binary::Contains), + datalog::Binary::Prefix => Ok(Binary::Prefix), + datalog::Binary::Suffix => Ok(Binary::Suffix), + datalog::Binary::Regex => Ok(Binary::Regex), + datalog::Binary::Add => Ok(Binary::Add), + datalog::Binary::Sub => Ok(Binary::Sub), + datalog::Binary::Mul => Ok(Binary::Mul), + datalog::Binary::Div => Ok(Binary::Div), + datalog::Binary::And => Ok(Binary::And), + datalog::Binary::Or => Ok(Binary::Or), + datalog::Binary::Intersection => Ok(Binary::Intersection), + datalog::Binary::Union => Ok(Binary::Union), + datalog::Binary::BitwiseAnd => Ok(Binary::BitwiseAnd), + datalog::Binary::BitwiseOr => Ok(Binary::BitwiseOr), + datalog::Binary::BitwiseXor => Ok(Binary::BitwiseXor), + datalog::Binary::NotEqual => Ok(Binary::NotEqual), + datalog::Binary::HeterogeneousEqual => Ok(Binary::HeterogeneousEqual), + datalog::Binary::HeterogeneousNotEqual => Ok(Binary::HeterogeneousNotEqual), + datalog::Binary::LazyAnd => Ok(Binary::LazyAnd), + datalog::Binary::LazyOr => Ok(Binary::LazyOr), + datalog::Binary::All => Ok(Binary::All), + datalog::Binary::Any => Ok(Binary::Any), + datalog::Binary::Get => Ok(Binary::Get), + datalog::Binary::Ffi(i) => Ok(Binary::Ffi(symbols.print_symbol(*i)?)), + } + } +} + impl From for Binary { fn from(binary: biscuit_parser::builder::Binary) -> Self { match binary { diff --git a/biscuit-parser/src/builder.rs b/biscuit-parser/src/builder.rs index 486fd85b..4a993157 100644 --- a/biscuit-parser/src/builder.rs +++ b/biscuit-parser/src/builder.rs @@ -341,11 +341,11 @@ impl ToTokens for Op { impl ToTokens for Unary { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { tokens.extend(match self { - Unary::Negate => quote! {::biscuit_auth::datalog::Unary::Negate }, - Unary::Parens => quote! {::biscuit_auth::datalog::Unary::Parens }, - Unary::Length => quote! {::biscuit_auth::datalog::Unary::Length }, - Unary::TypeOf => quote! {::biscuit_auth::datalog::Unary::TypeOf }, - Unary::Ffi(name) => quote! {::biscuit_auth::datalog::Unary::Ffi(#name.to_string()) }, + Unary::Negate => quote! {::biscuit_auth::builder::Unary::Negate }, + Unary::Parens => quote! {::biscuit_auth::builder::Unary::Parens }, + Unary::Length => quote! {::biscuit_auth::builder::Unary::Length }, + Unary::TypeOf => quote! {::biscuit_auth::builder::Unary::TypeOf }, + Unary::Ffi(name) => quote! {::biscuit_auth::builder::Unary::Ffi(#name.to_string()) }, }); } } @@ -354,39 +354,39 @@ impl ToTokens for Unary { impl ToTokens for Binary { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { tokens.extend(match self { - Binary::LessThan => quote! { ::biscuit_auth::datalog::Binary::LessThan }, - Binary::GreaterThan => quote! { ::biscuit_auth::datalog::Binary::GreaterThan }, - Binary::LessOrEqual => quote! { ::biscuit_auth::datalog::Binary::LessOrEqual }, - Binary::GreaterOrEqual => quote! { ::biscuit_auth::datalog::Binary::GreaterOrEqual }, - Binary::Equal => quote! { ::biscuit_auth::datalog::Binary::Equal }, - Binary::Contains => quote! { ::biscuit_auth::datalog::Binary::Contains }, - Binary::Prefix => quote! { ::biscuit_auth::datalog::Binary::Prefix }, - Binary::Suffix => quote! { ::biscuit_auth::datalog::Binary::Suffix }, - Binary::Regex => quote! { ::biscuit_auth::datalog::Binary::Regex }, - Binary::Add => quote! { ::biscuit_auth::datalog::Binary::Add }, - Binary::Sub => quote! { ::biscuit_auth::datalog::Binary::Sub }, - Binary::Mul => quote! { ::biscuit_auth::datalog::Binary::Mul }, - Binary::Div => quote! { ::biscuit_auth::datalog::Binary::Div }, - Binary::And => quote! { ::biscuit_auth::datalog::Binary::And }, - Binary::Or => quote! { ::biscuit_auth::datalog::Binary::Or }, - Binary::Intersection => quote! { ::biscuit_auth::datalog::Binary::Intersection }, - Binary::Union => quote! { ::biscuit_auth::datalog::Binary::Union }, - Binary::BitwiseAnd => quote! { ::biscuit_auth::datalog::Binary::BitwiseAnd }, - Binary::BitwiseOr => quote! { ::biscuit_auth::datalog::Binary::BitwiseOr }, - Binary::BitwiseXor => quote! { ::biscuit_auth::datalog::Binary::BitwiseXor }, - Binary::NotEqual => quote! { ::biscuit_auth::datalog::Binary::NotEqual }, + Binary::LessThan => quote! { ::biscuit_auth::builder::Binary::LessThan }, + Binary::GreaterThan => quote! { ::biscuit_auth::builder::Binary::GreaterThan }, + Binary::LessOrEqual => quote! { ::biscuit_auth::builder::Binary::LessOrEqual }, + Binary::GreaterOrEqual => quote! { ::biscuit_auth::builder::Binary::GreaterOrEqual }, + Binary::Equal => quote! { ::biscuit_auth::builder::Binary::Equal }, + Binary::Contains => quote! { ::biscuit_auth::builder::Binary::Contains }, + Binary::Prefix => quote! { ::biscuit_auth::builder::Binary::Prefix }, + Binary::Suffix => quote! { ::biscuit_auth::builder::Binary::Suffix }, + Binary::Regex => quote! { ::biscuit_auth::builder::Binary::Regex }, + Binary::Add => quote! { ::biscuit_auth::builder::Binary::Add }, + Binary::Sub => quote! { ::biscuit_auth::builder::Binary::Sub }, + Binary::Mul => quote! { ::biscuit_auth::builder::Binary::Mul }, + Binary::Div => quote! { ::biscuit_auth::builder::Binary::Div }, + Binary::And => quote! { ::biscuit_auth::builder::Binary::And }, + Binary::Or => quote! { ::biscuit_auth::builder::Binary::Or }, + Binary::Intersection => quote! { ::biscuit_auth::builder::Binary::Intersection }, + Binary::Union => quote! { ::biscuit_auth::builder::Binary::Union }, + Binary::BitwiseAnd => quote! { ::biscuit_auth::builder::Binary::BitwiseAnd }, + Binary::BitwiseOr => quote! { ::biscuit_auth::builder::Binary::BitwiseOr }, + Binary::BitwiseXor => quote! { ::biscuit_auth::builder::Binary::BitwiseXor }, + Binary::NotEqual => quote! { ::biscuit_auth::builder::Binary::NotEqual }, Binary::HeterogeneousEqual => { - quote! { ::biscuit_auth::datalog::Binary::HeterogeneousEqual} + quote! { ::biscuit_auth::builder::Binary::HeterogeneousEqual} } Binary::HeterogeneousNotEqual => { - quote! { ::biscuit_auth::datalog::Binary::HeterogeneousNotEqual} + quote! { ::biscuit_auth::builder::Binary::HeterogeneousNotEqual} } - Binary::LazyAnd => quote! { ::biscuit_auth::datalog::Binary::LazyAnd }, - Binary::LazyOr => quote! { ::biscuit_auth::datalog::Binary::LazyOr }, - Binary::All => quote! { ::biscuit_auth::datalog::Binary::All }, - Binary::Any => quote! { ::biscuit_auth::datalog::Binary::Any }, - Binary::Get => quote! { ::biscuit_auth::datalog::Binary::Get }, - Binary::Ffi(name) => quote! {::biscuit_auth::datalog::Binary::Ffi(#name.to_string()) }, + Binary::LazyAnd => quote! { ::biscuit_auth::builder::Binary::LazyAnd }, + Binary::LazyOr => quote! { ::biscuit_auth::builder::Binary::LazyOr }, + Binary::All => quote! { ::biscuit_auth::builder::Binary::All }, + Binary::Any => quote! { ::biscuit_auth::builder::Binary::Any }, + Binary::Get => quote! { ::biscuit_auth::builder::Binary::Get }, + Binary::Ffi(name) => quote! {::biscuit_auth::builder::Binary::Ffi(#name.to_string()) }, }); } } From a179587d497f3fa38da9e55d0629e9353ff90740 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Tue, 19 Nov 2024 16:20:28 +0100 Subject: [PATCH 8/8] ffi: store external functions in `World` This ensures consistent evaluation of functions and removes the need of storing them in `RunLimits` --- biscuit-auth/examples/testcases.rs | 48 +++++++------- biscuit-auth/src/datalog/mod.rs | 50 ++++----------- biscuit-auth/src/token/authorizer.rs | 63 +++++++++---------- biscuit-auth/src/token/authorizer/snapshot.rs | 1 - 4 files changed, 69 insertions(+), 93 deletions(-) diff --git a/biscuit-auth/examples/testcases.rs b/biscuit-auth/examples/testcases.rs index 648e2656..f69809d8 100644 --- a/biscuit-auth/examples/testcases.rs +++ b/biscuit-auth/examples/testcases.rs @@ -305,14 +305,21 @@ enum AuthorizerResult { } fn validate_token(root: &KeyPair, data: &[u8], authorizer_code: &str) -> Validation { - validate_token_with_limits(root, data, authorizer_code, RunLimits::default()) + validate_token_with_limits_and_external_functions( + root, + data, + authorizer_code, + RunLimits::default(), + Default::default(), + ) } -fn validate_token_with_limits( +fn validate_token_with_limits_and_external_functions( root: &KeyPair, data: &[u8], authorizer_code: &str, run_limits: RunLimits, + extern_funcs: HashMap, ) -> Validation { let token = match Biscuit::from(&data[..], &root.public()) { Ok(t) => t, @@ -333,6 +340,7 @@ fn validate_token_with_limits( } let mut authorizer = Authorizer::new(); + authorizer.set_extern_funcs(extern_funcs); authorizer.add_code(authorizer_code).unwrap(); let authorizer_code = authorizer.dump_code(); @@ -2303,28 +2311,26 @@ fn ffi(target: &str, root: &KeyPair, test: bool) -> TestResult { let mut validations = BTreeMap::new(); validations.insert( "".to_string(), - validate_token_with_limits( + validate_token_with_limits_and_external_functions( root, &data[..], "allow if true", - RunLimits { - extern_funcs: HashMap::from([( - "test".to_string(), - ExternFunc::new(Arc::new(|left, right| match (left, right) { - (t, None) => Ok(t), - (builder::Term::Str(left), Some(builder::Term::Str(right))) - if left == right => - { - Ok(builder::Term::Str("equal strings".to_string())) - } - (builder::Term::Str(_), Some(builder::Term::Str(_))) => { - Ok(builder::Term::Str("different strings".to_string())) - } - _ => Err("unsupported operands".to_string()), - })), - )]), - ..Default::default() - }, + RunLimits::default(), + HashMap::from([( + "test".to_string(), + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (t, None) => Ok(t), + (builder::Term::Str(left), Some(builder::Term::Str(right))) + if left == right => + { + Ok(builder::Term::Str("equal strings".to_string())) + } + (builder::Term::Str(_), Some(builder::Term::Str(_))) => { + Ok(builder::Term::Str("different strings".to_string())) + } + _ => Err("unsupported operands".to_string()), + })), + )]), ), ); diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index 4ddf5bcb..d1fba3d0 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -588,6 +588,7 @@ pub struct World { pub facts: FactSet, pub rules: RuleSet, pub iterations: u64, + pub extern_funcs: HashMap, } impl World { @@ -622,7 +623,7 @@ impl World { for (scope, rules) in self.rules.inner.iter() { let it = self.facts.iterator(scope); for (origin, rule) in rules { - for res in rule.apply(it.clone(), *origin, symbols, &limits.extern_funcs) { + for res in rule.apply(it.clone(), *origin, symbols, &self.extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -693,12 +694,11 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, - extern_funcs: &HashMap, ) -> Result { let mut new_facts = FactSet::default(); let it = self.facts.iterator(scope); //new_facts.extend(rule.apply(it, origin, symbols)); - for res in rule.apply(it.clone(), origin, symbols, extern_funcs) { + for res in rule.apply(it.clone(), origin, symbols, &self.extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -718,9 +718,8 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, - extern_funcs: &HashMap, ) -> Result { - rule.find_match(&self.facts, origin, scope, symbols, extern_funcs) + rule.find_match(&self.facts, origin, scope, symbols, &self.extern_funcs) } pub fn query_match_all( @@ -728,9 +727,8 @@ impl World { rule: Rule, scope: &TrustedOrigins, symbols: &SymbolTable, - extern_funcs: &HashMap, ) -> Result { - rule.check_match_all(&self.facts, scope, symbols, extern_funcs) + rule.check_match_all(&self.facts, scope, symbols, &self.extern_funcs) } } @@ -743,8 +741,6 @@ pub struct RunLimits { pub max_iterations: u64, /// maximum execution time pub max_time: Duration, - - pub extern_funcs: HashMap, } impl std::default::Default for RunLimits { @@ -753,7 +749,6 @@ impl std::default::Default for RunLimits { max_facts: 1000, max_iterations: 100, max_time: Duration::from_millis(1), - extern_funcs: Default::default(), } } } @@ -1056,8 +1051,7 @@ mod tests { println!("symbols: {:?}", syms); println!("testing r1: {}", syms.print_rule(&r1)); - let query_rule_result = - w.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()); + let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms); println!("grandparents query_rules: {:?}", query_rule_result); println!("current facts: {:?}", w.facts); @@ -1102,7 +1096,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1121,7 +1114,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default() ) ); println!( @@ -1138,7 +1130,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default() ) ); w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e])); @@ -1156,7 +1147,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); println!("grandparents after inserting parent(C, E): {:?}", res); @@ -1232,7 +1222,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1282,7 +1271,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1369,7 +1357,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap() .iter_all() @@ -1450,9 +1437,7 @@ mod tests { ); println!("testing r1: {}", syms.print_rule(&r1)); - let res = w - .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) - .unwrap(); + let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1490,9 +1475,7 @@ mod tests { ); println!("testing r2: {}", syms.print_rule(&r2)); - let res = w - .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) - .unwrap(); + let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1555,7 +1538,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1607,7 +1589,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1653,7 +1634,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1699,7 +1679,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1723,7 +1702,6 @@ mod tests { 0, &[0].iter().collect(), &syms, - &Default::default(), ) .unwrap(); @@ -1766,9 +1744,7 @@ mod tests { println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w - .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) - .unwrap(); + let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1807,9 +1783,7 @@ mod tests { ); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w - .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) - .unwrap(); + let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { @@ -1825,9 +1799,7 @@ mod tests { let r2 = rule(check, &[&read], &[pred(operation, &[&read])]); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r2: {}\n", syms.print_rule(&r2)); - let res = w - .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) - .unwrap(); + let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { diff --git a/biscuit-auth/src/token/authorizer.rs b/biscuit-auth/src/token/authorizer.rs index 75777834..060aa2f7 100644 --- a/biscuit-auth/src/token/authorizer.rs +++ b/biscuit-auth/src/token/authorizer.rs @@ -7,7 +7,7 @@ use super::builder_ext::{AuthorizerExt, BuilderExt}; use super::{Biscuit, Block}; use crate::builder::{self, CheckKind, Convert}; use crate::crypto::PublicKey; -use crate::datalog::{self, Origin, RunLimits, SymbolTable, TrustedOrigins}; +use crate::datalog::{self, ExternFunc, Origin, RunLimits, SymbolTable, TrustedOrigins}; use crate::error; use crate::time::Instant; use crate::token; @@ -397,6 +397,26 @@ impl Authorizer { self.limits = limits; } + /// Returns the currently registered external functions + pub fn external_funcs(&self) -> &HashMap { + &self.world.extern_funcs + } + + /// Replaces the registered external functions + pub fn set_extern_funcs(&mut self, extern_funcs: HashMap) { + self.world.extern_funcs = extern_funcs; + } + + /// Registers the provided external functions (possibly replacing already registered functions) + pub fn register_extern_funcs(&mut self, extern_funcs: HashMap) { + self.world.extern_funcs.extend(extern_funcs); + } + + /// Registers the provided external function (possibly replacing an already registered function) + pub fn register_extern_func(&mut self, name: String, func: ExternFunc) { + self.world.extern_funcs.insert(name, func); + } + /// run a query over the authorizer's Datalog engine to gather data /// /// ```rust @@ -469,15 +489,10 @@ impl Authorizer { &self.public_key_to_block_id, ); - let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; - let res = self.world.query_rule( - rule, - usize::MAX, - &rule_trusted_origins, - &self.symbols, - &extern_binary, - )?; + let res = self + .world + .query_rule(rule, usize::MAX, &rule_trusted_origins, &self.symbols)?; res.inner .into_iter() @@ -557,7 +572,6 @@ impl Authorizer { rule: datalog::Rule, limits: AuthorizerLimits, ) -> Result, error::Token> { - let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; let rule_trusted_origins = if rule.scopes.is_empty() { @@ -574,13 +588,9 @@ impl Authorizer { ) }; - let res = self.world.query_rule( - rule, - 0, - &rule_trusted_origins, - &self.symbols, - &extern_binary, - )?; + let res = self + .world + .query_rule(rule, 0, &rule_trusted_origins, &self.symbols)?; let r: HashSet<_> = res.into_iter().map(|(_, fact)| fact).collect(); @@ -751,20 +761,16 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, - )?, - CheckKind::All => self.world.query_match_all( - query, - &rule_trusted_origins, - &self.symbols, - &limits.extern_funcs, )?, + CheckKind::All => { + self.world + .query_match_all(query, &rule_trusted_origins, &self.symbols)? + } CheckKind::Reject => !self.world.query_match( query, usize::MAX, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, }; @@ -813,20 +819,17 @@ impl Authorizer { 0, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), 0, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, }; @@ -866,7 +869,6 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?; let now = Instant::now(); @@ -916,20 +918,17 @@ impl Authorizer { i + 1, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), i + 1, &rule_trusted_origins, &self.symbols, - &limits.extern_funcs, )?, }; diff --git a/biscuit-auth/src/token/authorizer/snapshot.rs b/biscuit-auth/src/token/authorizer/snapshot.rs index 247dd890..373aff9f 100644 --- a/biscuit-auth/src/token/authorizer/snapshot.rs +++ b/biscuit-auth/src/token/authorizer/snapshot.rs @@ -31,7 +31,6 @@ impl super::Authorizer { max_facts: limits.max_facts, max_iterations: limits.max_iterations, max_time: Duration::from_nanos(limits.max_time), - extern_funcs: Default::default(), }; let execution_time = Duration::from_nanos(execution_time);