From 9fafdf25cb9ead9519b8ab26bf4efa54a846fd93 Mon Sep 17 00:00:00 2001 From: Tomoko Uchida Date: Sat, 4 May 2024 05:15:21 +0900 Subject: [PATCH] Expose Tantivy's MoreLikeThisQuery (#257) --- src/query.rs | 50 ++++++++++++++++++++++++++++++++++++++++--- src/searcher.rs | 1 + tantivy/tantivy.pyi | 14 ++++++++++++ tests/tantivy_test.py | 38 +++++++++++++++++++++++++++++++- 4 files changed, 99 insertions(+), 4 deletions(-) diff --git a/src/query.rs b/src/query.rs index e450779..a963211 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,4 +1,4 @@ -use crate::{get_field, make_term, to_pyerr, Schema}; +use crate::{get_field, make_term, to_pyerr, DocAddress, Schema}; use pyo3::{ exceptions, prelude::*, @@ -100,7 +100,7 @@ impl Query { let terms = field_values .into_iter() .map(|field_value| { - make_term(&schema.inner, field_name, &field_value) + make_term(&schema.inner, field_name, field_value) }) .collect::, _>>()?; let inner = tv::query::TermSetQuery::new(terms); @@ -138,7 +138,7 @@ impl Query { transposition_cost_one: bool, prefix: bool, ) -> PyResult { - let term = make_term(&schema.inner, field_name, &text)?; + let term = make_term(&schema.inner, field_name, text)?; let inner = if prefix { tv::query::FuzzyTermQuery::new_prefix( term, @@ -272,6 +272,50 @@ impl Query { } } + #[staticmethod] + #[pyo3(signature = (doc_address, min_doc_frequency = Some(5), max_doc_frequency = None, min_term_frequency = Some(2), max_query_terms = Some(25), min_word_length = None, max_word_length = None, boost_factor = Some(1.0), stop_words = vec![]))] + #[allow(clippy::too_many_arguments)] + pub(crate) fn more_like_this_query( + doc_address: &DocAddress, + min_doc_frequency: Option, + max_doc_frequency: Option, + min_term_frequency: Option, + max_query_terms: Option, + min_word_length: Option, + max_word_length: Option, + boost_factor: Option, + stop_words: Vec, + ) -> PyResult { + let mut builder = tv::query::MoreLikeThisQuery::builder(); + if let Some(value) = min_doc_frequency { + builder = builder.with_min_doc_frequency(value); + } + if let Some(value) = max_doc_frequency { + builder = builder.with_max_doc_frequency(value); + } + if let Some(value) = min_term_frequency { + builder = builder.with_min_term_frequency(value); + } + if let Some(value) = max_query_terms { + builder = builder.with_max_query_terms(value); + } + if let Some(value) = min_word_length { + builder = builder.with_min_word_length(value); + } + if let Some(value) = max_word_length { + builder = builder.with_max_word_length(value); + } + if let Some(value) = boost_factor { + builder = builder.with_boost_factor(value); + } + builder = builder.with_stop_words(stop_words); + + let inner = builder.with_document(tv::DocAddress::from(doc_address)); + Ok(Query { + inner: Box::new(inner), + }) + } + /// Construct a Tantivy's ConstScoreQuery #[staticmethod] #[pyo3(signature = (query, score))] diff --git a/src/searcher.rs b/src/searcher.rs index c202bfd..3b0b912 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -153,6 +153,7 @@ impl Searcher { /// /// Raises a ValueError if there was an error with the search. #[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))] + #[allow(clippy::too_many_arguments)] fn search( &self, py: Python, diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 8eaeff0..131557e 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -227,6 +227,20 @@ class Query: def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query: pass + @staticmethod + def more_like_this_query( + doc_address: DocAddress, + min_doc_frequency: Optional[int] = 5, + max_doc_frequency: Optional[int] = None, + min_term_frequency: Optional[int] = 2, + max_query_terms: Optional[int] = 25, + min_word_length: Optional[int] = None, + max_word_length: Optional[int] = None, + boost_factor: Optional[float] = 1.0, + stop_words: list[str] = [] + ) -> Query: + pass + @staticmethod def const_score_query(query: Query, score: float) -> Query: pass diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 4883674..48a013e 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -1090,6 +1090,42 @@ class TestQuery(object): ): Query.regex_query(index.schema, "body", "fish(") + def test_more_like_this_query(self, ram_index): + index = ram_index + + # first, search the target doc + query = Query.term_query(index.schema, "title", "man") + result = index.searcher().search(query, 1) + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["title"] == ["The Old Man and the Sea"] + + # construct the default MLT Query + mlt_query = Query.more_like_this_query(doc_address) + assert ( + repr(mlt_query) + == "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(5), max_doc_frequency: None, min_term_frequency: Some(2), max_query_terms: Some(25), min_word_length: None, max_word_length: None, boost_factor: Some(1.0), stop_words: [] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })" + ) + result = index.searcher().search(mlt_query, 10) + assert len(result.hits) == 0 + + # construct a fine-tuned MLT Query + mlt_query = Query.more_like_this_query( + doc_address, + min_doc_frequency=2, + max_doc_frequency=10, + min_term_frequency=1, + max_query_terms=10, + min_word_length=2, + max_word_length=20, + boost_factor=2.0, + stop_words=["fish"]) + assert ( + repr(mlt_query) + == "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: [\"fish\"] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })" + ) + result = index.searcher().search(mlt_query, 10) + assert len(result.hits) > 0 def test_const_score_query(self, ram_index): index = ram_index @@ -1119,4 +1155,4 @@ class TestQuery(object): # wrong score type with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"): Query.const_score_query(query, "0.1") - \ No newline at end of file +