Skip to content

Commit

Permalink
Expose the method of boolean query (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-au-922 authored Apr 22, 2024
1 parent 9fa82ef commit 1d61b96
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod snippet;
use document::Document;
use facet::Facet;
use index::Index;
use query::Query;
use query::{Occur, Query};
use schema::Schema;
use schemabuilder::SchemaBuilder;
use searcher::{DocAddress, Order, SearchResult, Searcher};
Expand Down Expand Up @@ -87,6 +87,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Query>()?;
m.add_class::<Snippet>()?;
m.add_class::<SnippetGenerator>()?;
m.add_class::<Occur>()?;

m.add_wrapped(wrap_pymodule!(query_parser_error))?;

Expand Down
61 changes: 60 additions & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,55 @@
use crate::{make_term, Schema};
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString};
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple};
use tantivy as tv;

/// Custom Tuple struct to represent a pair of Occur and Query
/// for the BooleanQuery
struct OccurQueryPair(Occur, Query);

impl <'source> FromPyObject<'source> for OccurQueryPair {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let tuple = ob.downcast::<PyTuple>()?;
let occur = tuple.get_item(0)?.extract()?;
let query = tuple.get_item(1)?.extract()?;

Ok(OccurQueryPair(occur, query))
}
}


/// Tantivy's Occur
#[pyclass(frozen, module = "tantivy.tantivy")]
#[derive(Clone)]
pub enum Occur {
Must,
Should,
MustNot,
}

impl From<Occur> for tv::query::Occur {
fn from(occur: Occur) -> tv::query::Occur {
match occur {
Occur::Must => tv::query::Occur::Must,
Occur::Should => tv::query::Occur::Should,
Occur::MustNot => tv::query::Occur::MustNot,
}
}
}

/// Tantivy's Query
#[pyclass(frozen, module = "tantivy.tantivy")]
pub(crate) struct Query {
pub(crate) inner: Box<dyn tv::query::Query>,
}

impl Clone for Query {
fn clone(&self) -> Self {
Query {
inner: self.inner.box_clone(),
}
}
}

impl Query {
pub(crate) fn get(&self) -> &dyn tv::query::Query {
&self.inner
Expand Down Expand Up @@ -91,4 +133,21 @@ impl Query {
inner: Box::new(inner),
})
}

#[staticmethod]
#[pyo3(signature = (subqueries))]
pub(crate) fn boolean_query(
subqueries: Vec<(Occur, Query)>
) -> PyResult<Query> {
let dyn_subqueries = subqueries
.into_iter()
.map(|(occur, query)| (occur.into(), query.inner.box_clone()))
.collect::<Vec<_>>();

let inner = tv::query::BooleanQuery::from(dyn_subqueries);

Ok(Query {
inner: Box::new(inner),
})
}
}
11 changes: 9 additions & 2 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, Sequence


class Schema:
Expand Down Expand Up @@ -187,6 +187,10 @@ class Document:
def get_all(self, field_name: str) -> list[Any]:
pass

class Occur(Enum):
Must = 1
Should = 2
MustNot = 3

class Query:
@staticmethod
Expand All @@ -200,7 +204,10 @@ class Query:
@staticmethod
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
pass


@staticmethod
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
pass

class Order(Enum):
Asc = 1
Expand Down
62 changes: 60 additions & 2 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import tantivy
from conftest import schema, schema_numeric_fields
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur


class TestClass(object):
Expand Down Expand Up @@ -819,4 +819,62 @@ def test_fuzzy_term_query(self, ram_index):
titles.update(index.searcher().doc(doc_address)["title"])
assert titles == {"Frankenstein", "The Modern Prometheus"}


def test_boolean_query(self, ram_index):
index = ram_index
query1 = Query.fuzzy_term_query(index.schema, "title", "ice")
query2 = Query.fuzzy_term_query(index.schema, "title", "mna")
query = Query.boolean_query([
(Occur.Must, query1),
(Occur.Must, query2)
])

# no document should match both queries
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

query = Query.boolean_query([
(Occur.Should, query1),
(Occur.Should, query2)
])

# two documents should match, one for each query
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

titles = set()
for _, doc_address in result.hits:
titles.update(index.searcher().doc(doc_address)["title"])
assert (
"The Old Man and the Sea" in titles and
"Of Mice and Men" in titles
)

query = Query.boolean_query([
(Occur.MustNot, query1),
(Occur.Must, query1)
])

# must not should take precedence over must
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

query = Query.boolean_query((
(Occur.Should, query1),
(Occur.Should, query2)
))

# the Vec signature should fit the tuple signature
result = index.searcher().search(query, 10)
assert len(result.hits) == 2

# test invalid queries
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
Query.boolean_query([
(Occur.Must, Occur.Must, query1),
])

# test swapping the order of the tuple
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
Query.boolean_query([
(query1, Occur.Must),
])

0 comments on commit 1d61b96

Please sign in to comment.