Expose Tantivy's DisjunctionMaxQuery (#244)

Co-authored-by: Caleb Hattingh <caleb.hattingh@gmail.com>
master
Aécio Santos 2024-04-24 05:12:24 -07:00 committed by GitHub
parent 7651d2b2cb
commit deb88ccdcd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 2 deletions

View File

@ -1,6 +1,8 @@
use crate::{make_term, Schema}; use crate::{make_term, Schema};
use pyo3::{ use pyo3::{
exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple, exceptions,
prelude::*,
types::{PyAny, PyFloat, PyString, PyTuple},
}; };
use tantivy as tv; use tantivy as tv;
@ -151,4 +153,29 @@ impl Query {
inner: Box::new(inner), inner: Box::new(inner),
}) })
} }
/// Construct a Tantivy's DisjunctionMaxQuery
#[staticmethod]
pub(crate) fn disjunction_max_query(
subqueries: Vec<Query>,
tie_breaker: Option<&PyFloat>,
) -> PyResult<Query> {
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
.iter()
.map(|query| query.inner.box_clone())
.collect();
let dismax_query = if let Some(tie_breaker) = tie_breaker {
tv::query::DisjunctionMaxQuery::with_tie_breaker(
inner_queries,
tie_breaker.extract::<f32>()?,
)
} else {
tv::query::DisjunctionMaxQuery::new(inner_queries)
};
Ok(Query {
inner: Box::new(dismax_query),
})
}
} }

View File

@ -209,6 +209,11 @@ class Query:
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query: def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
pass pass
@staticmethod
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
pass
class Order(Enum): class Order(Enum):
Asc = 1 Asc = 1
Desc = 2 Desc = 2

View File

@ -877,4 +877,25 @@ class TestQuery(object):
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
Query.boolean_query([ Query.boolean_query([
(query1, Occur.Must), (query1, Occur.Must),
]) ])
def test_disjunction_max_query(self, ram_index):
index = ram_index
# query1 should match the doc: "The Old Man and the Sea"
query1 = Query.term_query(index.schema, "title", "sea")
# query2 should matches the doc: "Of Mice and Men"
query2 = Query.term_query(index.schema, "title", "mice")
# the disjunction max query should match both docs.
query = Query.disjunction_max_query([query1, query2])
result = index.searcher().search(query, 10)
assert len(result.hits) == 2
# the disjunction max query should also take a tie_breaker parameter
query = Query.disjunction_max_query([query1, query2], tie_breaker=0.5)
result = index.searcher().search(query, 10)
assert len(result.hits) == 2
with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)