Skip to content

Commit

Permalink
ffi: store external functions in World
Browse files Browse the repository at this point in the history
This ensures consistent evaluation of functions and removes the need of storing them in `RunLimits`
  • Loading branch information
divarvel committed Nov 19, 2024
1 parent 070bda0 commit a179587
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 93 deletions.
48 changes: 27 additions & 21 deletions biscuit-auth/examples/testcases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,21 @@ enum AuthorizerResult {
}

fn validate_token(root: &KeyPair, data: &[u8], authorizer_code: &str) -> Validation {
validate_token_with_limits(root, data, authorizer_code, RunLimits::default())
validate_token_with_limits_and_external_functions(
root,
data,
authorizer_code,
RunLimits::default(),
Default::default(),
)
}

fn validate_token_with_limits(
fn validate_token_with_limits_and_external_functions(
root: &KeyPair,
data: &[u8],
authorizer_code: &str,
run_limits: RunLimits,
extern_funcs: HashMap<String, ExternFunc>,
) -> Validation {
let token = match Biscuit::from(&data[..], &root.public()) {
Ok(t) => t,
Expand All @@ -333,6 +340,7 @@ fn validate_token_with_limits(
}

let mut authorizer = Authorizer::new();
authorizer.set_extern_funcs(extern_funcs);
authorizer.add_code(authorizer_code).unwrap();
let authorizer_code = authorizer.dump_code();

Expand Down Expand Up @@ -2303,28 +2311,26 @@ fn ffi(target: &str, root: &KeyPair, test: bool) -> TestResult {
let mut validations = BTreeMap::new();
validations.insert(
"".to_string(),
validate_token_with_limits(
validate_token_with_limits_and_external_functions(
root,
&data[..],
"allow if true",
RunLimits {
extern_funcs: HashMap::from([(
"test".to_string(),
ExternFunc::new(Arc::new(|left, right| match (left, right) {
(t, None) => Ok(t),
(builder::Term::Str(left), Some(builder::Term::Str(right)))
if left == right =>
{
Ok(builder::Term::Str("equal strings".to_string()))
}
(builder::Term::Str(_), Some(builder::Term::Str(_))) => {
Ok(builder::Term::Str("different strings".to_string()))
}
_ => Err("unsupported operands".to_string()),
})),
)]),
..Default::default()
},
RunLimits::default(),
HashMap::from([(
"test".to_string(),
ExternFunc::new(Arc::new(|left, right| match (left, right) {
(t, None) => Ok(t),
(builder::Term::Str(left), Some(builder::Term::Str(right)))
if left == right =>
{
Ok(builder::Term::Str("equal strings".to_string()))
}
(builder::Term::Str(_), Some(builder::Term::Str(_))) => {
Ok(builder::Term::Str("different strings".to_string()))

Check warning on line 2329 in biscuit-auth/examples/testcases.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/examples/testcases.rs#L2329

Added line #L2329 was not covered by tests
}
_ => Err("unsupported operands".to_string()),

Check warning on line 2331 in biscuit-auth/examples/testcases.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/examples/testcases.rs#L2331

Added line #L2331 was not covered by tests
})),
)]),
),
);

Expand Down
50 changes: 11 additions & 39 deletions biscuit-auth/src/datalog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ pub struct World {
pub facts: FactSet,
pub rules: RuleSet,
pub iterations: u64,
pub extern_funcs: HashMap<String, ExternFunc>,
}

impl World {
Expand Down Expand Up @@ -622,7 +623,7 @@ impl World {
for (scope, rules) in self.rules.inner.iter() {
let it = self.facts.iterator(scope);
for (origin, rule) in rules {
for res in rule.apply(it.clone(), *origin, symbols, &limits.extern_funcs) {
for res in rule.apply(it.clone(), *origin, symbols, &self.extern_funcs) {
match res {
Ok((origin, fact)) => {
new_facts.insert(&origin, fact);
Expand Down Expand Up @@ -693,12 +694,11 @@ impl World {
origin: usize,
scope: &TrustedOrigins,
symbols: &SymbolTable,
extern_funcs: &HashMap<String, ExternFunc>,
) -> Result<FactSet, Execution> {
let mut new_facts = FactSet::default();
let it = self.facts.iterator(scope);
//new_facts.extend(rule.apply(it, origin, symbols));
for res in rule.apply(it.clone(), origin, symbols, extern_funcs) {
for res in rule.apply(it.clone(), origin, symbols, &self.extern_funcs) {
match res {
Ok((origin, fact)) => {
new_facts.insert(&origin, fact);
Expand All @@ -718,19 +718,17 @@ impl World {
origin: usize,
scope: &TrustedOrigins,
symbols: &SymbolTable,
extern_funcs: &HashMap<String, ExternFunc>,
) -> Result<bool, Execution> {
rule.find_match(&self.facts, origin, scope, symbols, extern_funcs)
rule.find_match(&self.facts, origin, scope, symbols, &self.extern_funcs)
}

pub fn query_match_all(
&self,
rule: Rule,
scope: &TrustedOrigins,
symbols: &SymbolTable,
extern_funcs: &HashMap<String, ExternFunc>,
) -> Result<bool, Execution> {
rule.check_match_all(&self.facts, scope, symbols, extern_funcs)
rule.check_match_all(&self.facts, scope, symbols, &self.extern_funcs)
}
}

Expand All @@ -743,8 +741,6 @@ pub struct RunLimits {
pub max_iterations: u64,
/// maximum execution time
pub max_time: Duration,

pub extern_funcs: HashMap<String, ExternFunc>,
}

impl std::default::Default for RunLimits {
Expand All @@ -753,7 +749,6 @@ impl std::default::Default for RunLimits {
max_facts: 1000,
max_iterations: 100,
max_time: Duration::from_millis(1),
extern_funcs: Default::default(),
}
}
}
Expand Down Expand Up @@ -1056,8 +1051,7 @@ mod tests {

println!("symbols: {:?}", syms);
println!("testing r1: {}", syms.print_rule(&r1));
let query_rule_result =
w.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default());
let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms);
println!("grandparents query_rules: {:?}", query_rule_result);
println!("current facts: {:?}", w.facts);

Expand Down Expand Up @@ -1102,7 +1096,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand All @@ -1121,7 +1114,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default()
)
);
println!(
Expand All @@ -1138,7 +1130,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default()
)
);
w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e]));
Expand All @@ -1156,7 +1147,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();
println!("grandparents after inserting parent(C, E): {:?}", res);
Expand Down Expand Up @@ -1232,7 +1222,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand Down Expand Up @@ -1282,7 +1271,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand Down Expand Up @@ -1369,7 +1357,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap()
.iter_all()
Expand Down Expand Up @@ -1450,9 +1437,7 @@ mod tests {
);

println!("testing r1: {}", syms.print_rule(&r1));
let res = w
.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default())
.unwrap();
let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
for (_, fact) in res.iter_all() {
println!("\t{}", syms.print_fact(fact));
}
Expand Down Expand Up @@ -1490,9 +1475,7 @@ mod tests {
);

println!("testing r2: {}", syms.print_rule(&r2));
let res = w
.query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default())
.unwrap();
let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();
for (_, fact) in res.iter_all() {
println!("\t{}", syms.print_fact(fact));
}
Expand Down Expand Up @@ -1555,7 +1538,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand Down Expand Up @@ -1607,7 +1589,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand Down Expand Up @@ -1653,7 +1634,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand Down Expand Up @@ -1699,7 +1679,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand All @@ -1723,7 +1702,6 @@ mod tests {
0,
&[0].iter().collect(),
&syms,
&Default::default(),
)
.unwrap();

Expand Down Expand Up @@ -1766,9 +1744,7 @@ mod tests {

println!("world:\n{}\n", syms.print_world(&w));
println!("\ntesting r1: {}\n", syms.print_rule(&r1));
let res = w
.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default())
.unwrap();
let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();
for (_, fact) in res.iter_all() {
println!("\t{}", syms.print_fact(fact));
}
Expand Down Expand Up @@ -1807,9 +1783,7 @@ mod tests {
);
println!("world:\n{}\n", syms.print_world(&w));
println!("\ntesting r1: {}\n", syms.print_rule(&r1));
let res = w
.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default())
.unwrap();
let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap();

println!("generated facts:");
for (_, fact) in res.iter_all() {
Expand All @@ -1825,9 +1799,7 @@ mod tests {
let r2 = rule(check, &[&read], &[pred(operation, &[&read])]);
println!("world:\n{}\n", syms.print_world(&w));
println!("\ntesting r2: {}\n", syms.print_rule(&r2));
let res = w
.query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default())
.unwrap();
let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap();

println!("generated facts:");
for (_, fact) in res.iter_all() {
Expand Down
Loading

0 comments on commit a179587

Please sign in to comment.