From 1d61b96ffc1f73c38a6b313589aa98c4d1fe7c27 Mon Sep 17 00:00:00 2001 From: alex-au-922 <79151747+alex-au-922@users.noreply.github.com> Date: Tue, 23 Apr 2024 07:27:51 +0800 Subject: [PATCH] Expose the method of boolean query (#243) --- src/lib.rs | 3 ++- src/query.rs | 61 +++++++++++++++++++++++++++++++++++++++++- tantivy/tantivy.pyi | 11 ++++++-- tests/tantivy_test.py | 62 +++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 131 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2bf9e3e..47befe0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,7 @@ mod snippet; use document::Document; use facet::Facet; use index::Index; -use query::Query; +use query::{Occur, Query}; use schema::Schema; use schemabuilder::SchemaBuilder; use searcher::{DocAddress, Order, SearchResult, Searcher}; @@ -87,6 +87,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(query_parser_error))?; diff --git a/src/query.rs b/src/query.rs index e14d599..8d2730a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,13 +1,55 @@ use crate::{make_term, Schema}; -use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString}; +use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple}; use tantivy as tv; +/// Custom Tuple struct to represent a pair of Occur and Query +/// for the BooleanQuery +struct OccurQueryPair(Occur, Query); + +impl <'source> FromPyObject<'source> for OccurQueryPair { + fn extract(ob: &'source PyAny) -> PyResult { + let tuple = ob.downcast::()?; + let occur = tuple.get_item(0)?.extract()?; + let query = tuple.get_item(1)?.extract()?; + + Ok(OccurQueryPair(occur, query)) + } +} + + +/// Tantivy's Occur +#[pyclass(frozen, module = "tantivy.tantivy")] +#[derive(Clone)] +pub enum Occur { + Must, + Should, + MustNot, +} + +impl From for tv::query::Occur { + fn from(occur: Occur) -> tv::query::Occur { + match occur { + Occur::Must => tv::query::Occur::Must, + Occur::Should => tv::query::Occur::Should, + Occur::MustNot => tv::query::Occur::MustNot, + } + } +} + /// Tantivy's Query #[pyclass(frozen, module = "tantivy.tantivy")] pub(crate) struct Query { pub(crate) inner: Box, } +impl Clone for Query { + fn clone(&self) -> Self { + Query { + inner: self.inner.box_clone(), + } + } +} + impl Query { pub(crate) fn get(&self) -> &dyn tv::query::Query { &self.inner @@ -91,4 +133,21 @@ impl Query { inner: Box::new(inner), }) } + + #[staticmethod] + #[pyo3(signature = (subqueries))] + pub(crate) fn boolean_query( + subqueries: Vec<(Occur, Query)> + ) -> PyResult { + let dyn_subqueries = subqueries + .into_iter() + .map(|(occur, query)| (occur.into(), query.inner.box_clone())) + .collect::>(); + + let inner = tv::query::BooleanQuery::from(dyn_subqueries); + + Ok(Query { + inner: Box::new(inner), + }) + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index db95291..466a744 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -1,6 +1,6 @@ import datetime from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Sequence class Schema: @@ -187,6 +187,10 @@ class Document: def get_all(self, field_name: str) -> list[Any]: pass +class Occur(Enum): + Must = 1 + Should = 2 + MustNot = 3 class Query: @staticmethod @@ -200,7 +204,10 @@ class Query: @staticmethod def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query: pass - + + @staticmethod + def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query: + pass class Order(Enum): Asc = 1 diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index f8300d0..90f3b63 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -8,7 +8,7 @@ import pytest import tantivy from conftest import schema, schema_numeric_fields -from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query +from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur class TestClass(object): @@ -819,4 +819,62 @@ class TestQuery(object): titles.update(index.searcher().doc(doc_address)["title"]) assert titles == {"Frankenstein", "The Modern Prometheus"} - + def test_boolean_query(self, ram_index): + index = ram_index + query1 = Query.fuzzy_term_query(index.schema, "title", "ice") + query2 = Query.fuzzy_term_query(index.schema, "title", "mna") + query = Query.boolean_query([ + (Occur.Must, query1), + (Occur.Must, query2) + ]) + + # no document should match both queries + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + query = Query.boolean_query([ + (Occur.Should, query1), + (Occur.Should, query2) + ]) + + # two documents should match, one for each query + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + titles = set() + for _, doc_address in result.hits: + titles.update(index.searcher().doc(doc_address)["title"]) + assert ( + "The Old Man and the Sea" in titles and + "Of Mice and Men" in titles + ) + + query = Query.boolean_query([ + (Occur.MustNot, query1), + (Occur.Must, query1) + ]) + + # must not should take precedence over must + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + query = Query.boolean_query(( + (Occur.Should, query1), + (Occur.Should, query2) + )) + + # the Vec signature should fit the tuple signature + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + # test invalid queries + with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"): + Query.boolean_query([ + (Occur.Must, Occur.Must, query1), + ]) + + # test swapping the order of the tuple + with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"): + Query.boolean_query([ + (query1, Occur.Must), + ]) \ No newline at end of file