Expose Tantivy's DisjunctionMaxQuery (#244)
Co-authored-by: Caleb Hattingh <caleb.hattingh@gmail.com>
master
parent
7651d2b2cb
commit
deb88ccdcd
29
src/query.rs
29
src/query.rs
|
@ -1,6 +1,8 @@
|
|||
use crate::{make_term, Schema};
|
||||
use pyo3::{
|
||||
exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple,
|
||||
exceptions,
|
||||
prelude::*,
|
||||
types::{PyAny, PyFloat, PyString, PyTuple},
|
||||
};
|
||||
use tantivy as tv;
|
||||
|
||||
|
@ -151,4 +153,29 @@ impl Query {
|
|||
inner: Box::new(inner),
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct a Tantivy's DisjunctionMaxQuery
|
||||
#[staticmethod]
|
||||
pub(crate) fn disjunction_max_query(
|
||||
subqueries: Vec<Query>,
|
||||
tie_breaker: Option<&PyFloat>,
|
||||
) -> PyResult<Query> {
|
||||
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
|
||||
.iter()
|
||||
.map(|query| query.inner.box_clone())
|
||||
.collect();
|
||||
|
||||
let dismax_query = if let Some(tie_breaker) = tie_breaker {
|
||||
tv::query::DisjunctionMaxQuery::with_tie_breaker(
|
||||
inner_queries,
|
||||
tie_breaker.extract::<f32>()?,
|
||||
)
|
||||
} else {
|
||||
tv::query::DisjunctionMaxQuery::new(inner_queries)
|
||||
};
|
||||
|
||||
Ok(Query {
|
||||
inner: Box::new(dismax_query),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -209,6 +209,11 @@ class Query:
|
|||
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
|
||||
pass
|
||||
|
||||
@staticmethod
def disjunction_max_query(
    subqueries: Sequence[Query],
    tie_breaker: Optional[float] = None,
) -> Query:
    """Build a disjunction-max query from ``subqueries``.

    ``tie_breaker`` is optional; when omitted, the query is built
    without an explicit tie breaker.
    """
    pass
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
Asc = 1
|
||||
Desc = 2
|
||||
|
|
|
@ -877,4 +877,25 @@ class TestQuery(object):
|
|||
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
|
||||
Query.boolean_query([
|
||||
(query1, Occur.Must),
|
||||
])
|
||||
])
|
||||
|
||||
def test_disjunction_max_query(self, ram_index):
    """A disjunction-max query returns the docs matched by its subqueries."""
    index = ram_index

    # Should match the doc "The Old Man and the Sea".
    sea_query = Query.term_query(index.schema, "title", "sea")
    # Should match the doc "Of Mice and Men".
    mice_query = Query.term_query(index.schema, "title", "mice")

    # Without a tie breaker, the combined query should match both docs.
    dismax = Query.disjunction_max_query([sea_query, mice_query])
    found = index.searcher().search(dismax, 10)
    assert len(found.hits) == 2

    # The optional tie_breaker parameter is accepted as well.
    dismax = Query.disjunction_max_query([sea_query, mice_query], tie_breaker=0.5)
    found = index.searcher().search(dismax, 10)
    assert len(found.hits) == 2

    # A non-Query element in the subquery list is rejected with a TypeError.
    with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
        Query.disjunction_max_query([sea_query, "not a query"], tie_breaker=0.5)
|
||||
|
|
Loading…
Reference in New Issue