Expose Tantivy's DisjunctionMaxQuery (#244)
Co-authored-by: Caleb Hattingh <caleb.hattingh@gmail.com>
master
parent
7651d2b2cb
commit
deb88ccdcd
29
src/query.rs
29
src/query.rs
|
@ -1,6 +1,8 @@
|
||||||
use crate::{make_term, Schema};
|
use crate::{make_term, Schema};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple,
|
exceptions,
|
||||||
|
prelude::*,
|
||||||
|
types::{PyAny, PyFloat, PyString, PyTuple},
|
||||||
};
|
};
|
||||||
use tantivy as tv;
|
use tantivy as tv;
|
||||||
|
|
||||||
|
@ -151,4 +153,29 @@ impl Query {
|
||||||
inner: Box::new(inner),
|
inner: Box::new(inner),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct a Tantivy's DisjunctionMaxQuery
|
||||||
|
#[staticmethod]
|
||||||
|
pub(crate) fn disjunction_max_query(
|
||||||
|
subqueries: Vec<Query>,
|
||||||
|
tie_breaker: Option<&PyFloat>,
|
||||||
|
) -> PyResult<Query> {
|
||||||
|
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
|
||||||
|
.iter()
|
||||||
|
.map(|query| query.inner.box_clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let dismax_query = if let Some(tie_breaker) = tie_breaker {
|
||||||
|
tv::query::DisjunctionMaxQuery::with_tie_breaker(
|
||||||
|
inner_queries,
|
||||||
|
tie_breaker.extract::<f32>()?,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
tv::query::DisjunctionMaxQuery::new(inner_queries)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Query {
|
||||||
|
inner: Box::new(dismax_query),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -209,6 +209,11 @@ class Query:
|
||||||
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
|
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
def disjunction_max_query(
    subqueries: Sequence[Query],
    tie_breaker: Optional[float] = None,
) -> Query:
    """Build a query backed by Tantivy's DisjunctionMaxQuery.

    Args:
        subqueries: The queries to combine into the disjunction.
        tie_breaker: Optional float forwarded to Tantivy's
            ``DisjunctionMaxQuery::with_tie_breaker``; when omitted the
            disjunction is built without an explicit tie breaker.

    Returns:
        A new Query wrapping the disjunction.
    """
    pass
|
||||||
|
|
||||||
|
|
||||||
class Order(Enum):
|
class Order(Enum):
|
||||||
Asc = 1
|
Asc = 1
|
||||||
Desc = 2
|
Desc = 2
|
||||||
|
|
|
@ -877,4 +877,25 @@ class TestQuery(object):
|
||||||
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
|
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
|
||||||
Query.boolean_query([
|
Query.boolean_query([
|
||||||
(query1, Occur.Must),
|
(query1, Occur.Must),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def test_disjunction_max_query(self, ram_index):
    """Check disjunction-max queries with and without a tie breaker, and with a bad subquery."""
    # Expected to match the doc "The Old Man and the Sea".
    sea_query = Query.term_query(ram_index.schema, "title", "sea")
    # Expected to match the doc "Of Mice and Men".
    mice_query = Query.term_query(ram_index.schema, "title", "mice")

    # Without a tie breaker, the combined query should hit both docs.
    dismax = Query.disjunction_max_query([sea_query, mice_query])
    hits = ram_index.searcher().search(dismax, 10).hits
    assert len(hits) == 2

    # Supplying a tie_breaker must be accepted and still hit both docs.
    dismax = Query.disjunction_max_query([sea_query, mice_query], tie_breaker=0.5)
    hits = ram_index.searcher().search(dismax, 10).hits
    assert len(hits) == 2

    # A non-Query element in the subquery list is rejected with a TypeError.
    with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
        Query.disjunction_max_query([sea_query, "not a query"], tie_breaker=0.5)
|
||||||
|
|
Loading…
Reference in New Issue