diff --git a/src/query.rs b/src/query.rs index 00e35ca..306b40b 100644 --- a/src/query.rs +++ b/src/query.rs @@ -89,6 +89,26 @@ impl Query { }) } + /// Construct a Tantivy's TermSetQuery + #[staticmethod] + #[pyo3(signature = (schema, field_name, field_values))] + pub(crate) fn term_set_query( + schema: &Schema, + field_name: &str, + field_values: Vec<&PyAny>, + ) -> PyResult { + let terms = field_values + .into_iter() + .map(|field_value| { + make_term(&schema.inner, field_name, &field_value) + }) + .collect::, _>>()?; + let inner = tv::query::TermSetQuery::new(terms); + Ok(Query { + inner: Box::new(inner), + }) + } + /// Construct a Tantivy's AllQuery #[staticmethod] pub(crate) fn all_query() -> PyResult { diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 524523c..c5f91e9 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -197,6 +197,10 @@ class Query: ) -> Query: pass + @staticmethod + def term_set_query(schema: Schema, field_name: str, field_values: Sequence[Any]) -> Query: + pass + @staticmethod def all_query() -> Query: pass diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 0f6b987..9e904bd 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -765,6 +765,35 @@ class TestQuery(object): searched_doc = index.searcher().doc(doc_address) assert searched_doc["title"] == ["The Old Man and the Sea"] + def test_term_set_query(self, ram_index): + index = ram_index + + # Should match 1 document that contains both terms + terms = ["old", "man"] + query = Query.term_set_query(index.schema, "title", terms) + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["title"] == ["The Old Man and the Sea"] + + # Should not match any document since the term does not exist in the index + terms = ["a long term that does not exist in the index"] + query = Query.term_set_query(index.schema, "title", terms) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + # Should not match any document when the terms list is empty + terms = [] + query = Query.term_set_query(index.schema, "title", terms) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + # Should fail to create the query due to the invalid list object in the terms list + with pytest.raises(ValueError, match = r"Can't create a term for Field `title` with value `\[\]`"): + terms = ["old", [], "man"] + query = Query.term_set_query(index.schema, "title", terms) + def test_all_query(self, ram_index): index = ram_index query = Query.all_query()