From deb88ccdcdbbb1aad0bca3e691bc58bfaca23133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?A=C3=A9cio=20Santos?= <150570+aecio@users.noreply.github.com> Date: Wed, 24 Apr 2024 05:12:24 -0700 Subject: [PATCH] Expose Tantivy's DisjunctionMaxQuery (#244) Co-authored-by: Caleb Hattingh --- src/query.rs | 29 ++++++++++++++++++++++++++++- tantivy/tantivy.pyi | 5 +++++ tests/tantivy_test.py | 23 ++++++++++++++++++++++- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/query.rs b/src/query.rs index bf036fe..a310d04 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,6 +1,8 @@ use crate::{make_term, Schema}; use pyo3::{ - exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple, + exceptions, + prelude::*, + types::{PyAny, PyFloat, PyString, PyTuple}, }; use tantivy as tv; @@ -151,4 +153,29 @@ impl Query { inner: Box::new(inner), }) } + + /// Construct a Tantivy's DisjunctionMaxQuery + #[staticmethod] + pub(crate) fn disjunction_max_query( + subqueries: Vec, + tie_breaker: Option<&PyFloat>, + ) -> PyResult { + let inner_queries: Vec> = subqueries + .iter() + .map(|query| query.inner.box_clone()) + .collect(); + + let dismax_query = if let Some(tie_breaker) = tie_breaker { + tv::query::DisjunctionMaxQuery::with_tie_breaker( + inner_queries, + tie_breaker.extract::()?, + ) + } else { + tv::query::DisjunctionMaxQuery::new(inner_queries) + }; + + Ok(Query { + inner: Box::new(dismax_query), + }) + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 466a744..ee267b8 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -209,6 +209,11 @@ class Query: def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query: pass + @staticmethod + def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query: + pass + + class Order(Enum): Asc = 1 Desc = 2 diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 90f3b63..0124c2f 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -877,4 +877,25 @@ class TestQuery(object): with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): Query.boolean_query([ (query1, Occur.Must), - ]) \ No newline at end of file + ]) + + def test_disjunction_max_query(self, ram_index): + index = ram_index + + # query1 should match the doc: "The Old Man and the Sea" + query1 = Query.term_query(index.schema, "title", "sea") + # query2 should matches the doc: "Of Mice and Men" + query2 = Query.term_query(index.schema, "title", "mice") + # the disjunction max query should match both docs. + query = Query.disjunction_max_query([query1, query2]) + + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + # the disjunction max query should also take a tie_breaker parameter + query = Query.disjunction_max_query([query1, query2], tie_breaker=0.5) + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"): + query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)