Skip to content

Commit

Permalink
fix(cognitarium): ensure equality of same var values in a triple pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
amimart committed Feb 27, 2024
1 parent 870d0f3 commit 3c05fe8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 37 deletions.
77 changes: 51 additions & 26 deletions contracts/okp4-cognitarium/src/querier/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,29 +395,39 @@ impl<'a> TriplePatternIterator<'a> {
},
})
}

fn map_triple(&self, triple: Triple) -> Option<ResolvedVariables> {
let mut vars: ResolvedVariables = self.input.clone();

if let Some(v) = self.output_bindings.0 {
vars.merge_index(v, ResolvedVariable::Subject(triple.subject))?;
}
if let Some(v) = self.output_bindings.1 {
vars.merge_index(v, ResolvedVariable::Predicate(triple.predicate))?;
}
if let Some(v) = self.output_bindings.2 {
vars.merge_index(v, ResolvedVariable::Object(triple.object))?;
}

Some(vars)
}
}

impl<'a> Iterator for TriplePatternIterator<'a> {
type Item = StdResult<ResolvedVariables>;

fn next(&mut self) -> Option<Self::Item> {
self.triple_iter.next().map(|res| {
res.map(|triple| -> ResolvedVariables {
let mut vars: ResolvedVariables = self.input.clone();
let next = self.triple_iter.next()?;

if let Some(v) = self.output_bindings.0 {
vars.set(v, ResolvedVariable::Subject(triple.subject));
}
if let Some(v) = self.output_bindings.1 {
vars.set(v, ResolvedVariable::Predicate(triple.predicate));
}
if let Some(v) = self.output_bindings.2 {
vars.set(v, ResolvedVariable::Object(triple.object));
}
let maybe_next = match next {
Ok(triple) => self.map_triple(triple).map(|r| Ok(r)),
Err(e) => Some(Err(e)),
};

vars
})
})
if maybe_next == None {
return self.next();
}
maybe_next
}
}

Expand Down Expand Up @@ -1016,6 +1026,21 @@ mod test {
)])],
)),
},
TestCase {
plan: QueryPlan {
entrypoint: QueryNode::TriplePattern {
subject: PatternValue::Constant(Subject::Named(state::Node {
namespace: 0,
value: "97ff7e16-c08d-47be-8475-211016c82e33".to_string(),
})),
predicate: PatternValue::Variable(0),
object: PatternValue::Variable(0),
},
variables: vec![PlanVariable::Basic("v".to_string())],
},
selection: vec![SelectItem::Variable("v".to_string())],
expects: Ok((vec!["v".to_string()], vec![])),
},
TestCase {
plan: QueryPlan {
entrypoint: QueryNode::Limit {
Expand Down Expand Up @@ -1309,13 +1334,13 @@ mod test {
let result = ForLoopJoinIterator::new(
Box::new(case.left.iter().map(|v| {
let mut vars = ResolvedVariables::with_capacity(3);
vars.set(1, ResolvedVariable::Subject(Subject::Blank(v.clone())));
vars.merge_index(1, ResolvedVariable::Subject(Subject::Blank(v.clone())));
Ok(vars)
})),
Rc::new(|input| {
Box::new(case.right.iter().map(move |v| {
let mut vars = input.clone();
vars.set(2, ResolvedVariable::Subject(Subject::Blank(v.clone())));
vars.merge_index(2, ResolvedVariable::Subject(Subject::Blank(v.clone())));
Ok(vars)
}))
}),
Expand All @@ -1328,8 +1353,8 @@ mod test {
.iter()
.map(|(v1, v2)| {
let mut vars = ResolvedVariables::with_capacity(3);
vars.set(1, ResolvedVariable::Subject(Subject::Blank(v1.clone())));
vars.set(2, ResolvedVariable::Subject(Subject::Blank(v2.clone())));
vars.merge_index(1, ResolvedVariable::Subject(Subject::Blank(v1.clone())));
vars.merge_index(2, ResolvedVariable::Subject(Subject::Blank(v2.clone())));
vars
})
.collect();
Expand Down Expand Up @@ -1383,13 +1408,13 @@ mod test {
.iter()
.map(|v| {
let mut vars = ResolvedVariables::with_capacity(2);
vars.set(0, ResolvedVariable::Subject(Subject::Blank(v.clone())));
vars.merge_index(0, ResolvedVariable::Subject(Subject::Blank(v.clone())));
vars
})
.collect(),
Box::new(case.left.iter().map(|v| {
let mut vars = ResolvedVariables::with_capacity(2);
vars.set(1, ResolvedVariable::Subject(Subject::Blank(v.clone())));
vars.merge_index(1, ResolvedVariable::Subject(Subject::Blank(v.clone())));
Ok(vars)
})),
VecDeque::new(),
Expand All @@ -1403,10 +1428,10 @@ mod test {
.map(|v| {
let mut vars = ResolvedVariables::with_capacity(2);
if let Some(val) = v.get(0) {
vars.set(0, ResolvedVariable::Subject(Subject::Blank(val.clone())));
vars.merge_index(0, ResolvedVariable::Subject(Subject::Blank(val.clone())));
}
if let Some(val) = v.get(1) {
vars.set(1, ResolvedVariable::Subject(Subject::Blank(val.clone())));
vars.merge_index(1, ResolvedVariable::Subject(Subject::Blank(val.clone())));
}
vars
})
Expand All @@ -1426,9 +1451,9 @@ mod test {
let t_object = Object::Blank("o".to_string());

let mut variables = ResolvedVariables::with_capacity(6);
variables.set(1, ResolvedVariable::Subject(t_subject.clone()));
variables.set(2, ResolvedVariable::Predicate(t_predicate.clone()));
variables.set(3, ResolvedVariable::Object(t_object.clone()));
variables.merge_index(1, ResolvedVariable::Subject(t_subject.clone()));
variables.merge_index(2, ResolvedVariable::Predicate(t_predicate.clone()));
variables.merge_index(3, ResolvedVariable::Object(t_object.clone()));

struct TestCase {
subject: PatternValue<Subject>,
Expand Down
27 changes: 16 additions & 11 deletions contracts/okp4-cognitarium/src/querier/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,13 @@ impl ResolvedVariables {
Some(Self { variables: merged })
}

pub fn set(&mut self, index: usize, var: ResolvedVariable) {
self.variables[index] = Some(var);
pub fn merge_index(&mut self, index: usize, var: ResolvedVariable) -> Option<()> {
if let Some(old) = self.get(index) {
(*old == var).then(|| ())
} else {
self.variables[index] = Some(var);
Some(())
}
}

pub fn get(&self, index: usize) -> &Option<ResolvedVariable> {
Expand Down Expand Up @@ -295,21 +300,21 @@ mod tests {
#[test]
fn merged_variables() {
let mut vars1 = ResolvedVariables::with_capacity(3);
vars1.set(
vars1.merge_index(
0,
ResolvedVariable::Object(Object::Blank("foo".to_string())),
);
vars1.set(
vars1.merge_index(
2,
ResolvedVariable::Object(Object::Blank("bar".to_string())),
);

let mut vars2 = ResolvedVariables::with_capacity(3);
vars2.set(
vars2.merge_index(
1,
ResolvedVariable::Object(Object::Blank("pop".to_string())),
);
vars2.set(
vars2.merge_index(
2,
ResolvedVariable::Object(Object::Blank("bar".to_string())),
);
Expand All @@ -321,15 +326,15 @@ mod tests {
assert_eq!(vars1.get(1), &None);

let mut expected_result = ResolvedVariables::with_capacity(3);
expected_result.set(
expected_result.merge_index(
0,
ResolvedVariable::Object(Object::Blank("foo".to_string())),
);
expected_result.set(
expected_result.merge_index(
1,
ResolvedVariable::Object(Object::Blank("pop".to_string())),
);
expected_result.set(
expected_result.merge_index(
2,
ResolvedVariable::Object(Object::Blank("bar".to_string())),
);
Expand All @@ -338,11 +343,11 @@ mod tests {
assert_eq!(result, Some(expected_result));

let mut vars3 = ResolvedVariables::with_capacity(3);
vars3.set(
vars3.merge_index(
1,
ResolvedVariable::Object(Object::Blank("pop".to_string())),
);
vars3.set(
vars3.merge_index(
2,
ResolvedVariable::Predicate(Node {
namespace: 0,
Expand Down

0 comments on commit 3c05fe8

Please sign in to comment.