From 10553de76e41c61c94ff4340862b99f930dc0e5a Mon Sep 17 00:00:00 2001 From: alex-au-922 <79151747+alex-au-922@users.noreply.github.com> Date: Sun, 28 Apr 2024 18:54:21 +0800 Subject: [PATCH] Expose const score query (#256) --- src/query.rs | 16 ++++++++++++++++ tantivy/tantivy.pyi | 4 ++++ tests/tantivy_test.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/src/query.rs b/src/query.rs index 306b40b..1160f6a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -157,6 +157,7 @@ impl Query { }) } + /// Construct a Tantivy's BooleanQuery #[staticmethod] #[pyo3(signature = (subqueries))] pub(crate) fn boolean_query( @@ -199,6 +200,7 @@ impl Query { }) } + /// Construct a Tantivy's BoostQuery #[staticmethod] #[pyo3(signature = (query, boost))] pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult { @@ -208,6 +210,7 @@ impl Query { }) } + /// Construct a Tantivy's RegexQuery #[staticmethod] #[pyo3(signature = (schema, field_name, regex_pattern))] pub(crate) fn regex_query( @@ -226,4 +229,17 @@ impl Query { Err(e) => Err(to_pyerr(e)), } } + + /// Construct a Tantivy's ConstScoreQuery + #[staticmethod] + #[pyo3(signature = (query, score))] + pub(crate) fn const_score_query( + query: Query, + score: f32, + ) -> PyResult { + let inner = tv::query::ConstScoreQuery::new(query.inner, score); + Ok(Query { + inner: Box::new(inner), + }) + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index c5f91e9..9325ae7 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -233,6 +233,10 @@ class Query: def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query: pass + @staticmethod + def const_score_query(query: Query, score: float) -> Query: + pass + class Order(Enum): Asc = 1 Desc = 2 diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 9e904bd..bbb3219 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -1057,3 +1057,34 @@ class TestQuery(object): ValueError, match=r"An invalid argument was passed: 'fish\('" ): Query.regex_query(index.schema, "body", "fish(") + + def test_const_score_query(self, ram_index): + index = ram_index + + query = Query.regex_query(index.schema, "body", "fish") + const_score_query = Query.const_score_query( + query, score = 1.0 + ) + result = index.searcher().search(const_score_query, 10) + assert len(result.hits) == 1 + score, _ = result.hits[0] + # the score should be 1.0 + assert score == pytest.approx(1.0) + + const_score_query = Query.const_score_query( + Query.const_score_query( + query, score = 1.0 + ), score = 0.5 + ) + + result = index.searcher().search(const_score_query, 10) + assert len(result.hits) == 1 + score, _ = result.hits[0] + # nested const score queries should retain the + # score of the outer query + assert score == pytest.approx(0.5) + + # wrong score type + with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"): + Query.const_score_query(query, "0.1") + \ No newline at end of file