Expose the method of boolean query (#243)

master
alex-au-922 2024-04-23 07:27:51 +08:00 committed by GitHub
parent 9fa82ef29c
commit 1d61b96ffc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 131 additions and 6 deletions

View File

@ -15,7 +15,7 @@ mod snippet;
use document::Document; use document::Document;
use facet::Facet; use facet::Facet;
use index::Index; use index::Index;
use query::Query; use query::{Occur, Query};
use schema::Schema; use schema::Schema;
use schemabuilder::SchemaBuilder; use schemabuilder::SchemaBuilder;
use searcher::{DocAddress, Order, SearchResult, Searcher}; use searcher::{DocAddress, Order, SearchResult, Searcher};
@ -87,6 +87,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Query>()?; m.add_class::<Query>()?;
m.add_class::<Snippet>()?; m.add_class::<Snippet>()?;
m.add_class::<SnippetGenerator>()?; m.add_class::<SnippetGenerator>()?;
m.add_class::<Occur>()?;
m.add_wrapped(wrap_pymodule!(query_parser_error))?; m.add_wrapped(wrap_pymodule!(query_parser_error))?;

View File

@ -1,13 +1,55 @@
use crate::{make_term, Schema}; use crate::{make_term, Schema};
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString}; use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple};
use tantivy as tv; use tantivy as tv;
/// One clause of a boolean query: an [`Occur`] paired with the [`Query`]
/// it applies to.
///
/// Extracted from a Python tuple of exactly two elements, `(occur, query)`.
struct OccurQueryPair(Occur, Query);

impl<'source> FromPyObject<'source> for OccurQueryPair {
    /// Converts a Python `(Occur, Query)` tuple into an [`OccurQueryPair`].
    ///
    /// # Errors
    ///
    /// Fails if `ob` is not a tuple, if the tuple does not hold exactly two
    /// elements, or if either element cannot be extracted as the expected
    /// type.
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        // Explicit downcast so that non-tuples are rejected up front.
        let tuple = ob.downcast::<PyTuple>()?;
        // Delegate to the built-in 2-tuple extraction: unlike reading items
        // 0 and 1 by index, it rejects tuples of any other length instead of
        // silently ignoring trailing elements.
        let (occur, query): (Occur, Query) = tuple.extract()?;
        Ok(OccurQueryPair(occur, query))
    }
}
/// Tantivy's Occur
///
/// Python-facing counterpart of `tantivy::query::Occur`; every variant
/// converts to the tantivy variant of the same name.
#[pyclass(frozen, module = "tantivy.tantivy")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Occur {
    /// Converts to `tantivy::query::Occur::Must`.
    Must,
    /// Converts to `tantivy::query::Occur::Should`.
    Should,
    /// Converts to `tantivy::query::Occur::MustNot`.
    MustNot,
}
impl From<Occur> for tv::query::Occur {
    /// Translates the Python-facing [`Occur`] into tantivy's own `Occur`,
    /// variant for variant.
    fn from(value: Occur) -> Self {
        match value {
            Occur::MustNot => Self::MustNot,
            Occur::Should => Self::Should,
            Occur::Must => Self::Must,
        }
    }
}
/// Tantivy's Query /// Tantivy's Query
#[pyclass(frozen, module = "tantivy.tantivy")] #[pyclass(frozen, module = "tantivy.tantivy")]
pub(crate) struct Query { pub(crate) struct Query {
pub(crate) inner: Box<dyn tv::query::Query>, pub(crate) inner: Box<dyn tv::query::Query>,
} }
impl Clone for Query {
fn clone(&self) -> Self {
Query {
inner: self.inner.box_clone(),
}
}
}
impl Query { impl Query {
pub(crate) fn get(&self) -> &dyn tv::query::Query { pub(crate) fn get(&self) -> &dyn tv::query::Query {
&self.inner &self.inner
@ -91,4 +133,21 @@ impl Query {
inner: Box::new(inner), inner: Box::new(inner),
}) })
} }
/// Builds a boolean query that combines several subqueries.
///
/// `subqueries` is a sequence of `(Occur, Query)` pairs; each pair becomes
/// one clause of a `tantivy::query::BooleanQuery`, with the `Occur`
/// converted to its tantivy counterpart.
///
/// An argument whose elements are not `(Occur, Query)` pairs is rejected
/// while the argument is converted, before this body runs.
#[staticmethod]
#[pyo3(signature = (subqueries))]
pub(crate) fn boolean_query(
    subqueries: Vec<(Occur, Query)>,
) -> PyResult<Query> {
    // Every `Query` in `subqueries` is already an owned value, so its boxed
    // tantivy query is moved into the clause list rather than cloned again.
    let clauses: Vec<(tv::query::Occur, Box<dyn tv::query::Query>)> = subqueries
        .into_iter()
        .map(|(occur, query)| (occur.into(), query.inner))
        .collect();

    Ok(Query {
        inner: Box::new(tv::query::BooleanQuery::from(clauses)),
    })
}
} }

View File

@ -1,6 +1,6 @@
import datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional, Sequence
class Schema: class Schema:
@ -187,6 +187,10 @@ class Document:
def get_all(self, field_name: str) -> list[Any]: def get_all(self, field_name: str) -> list[Any]:
pass pass
# Type stub for the `Occur` class exported by the native module. Its members
# are the first element of each `(Occur, Query)` pair passed to
# `Query.boolean_query`.
# NOTE(review): the 1-based values follow the numbering style of the `Order`
# stub; confirm whether they are meant to match the native enum's integer
# values.
class Occur(Enum):
Must = 1
Should = 2
MustNot = 3
class Query: class Query:
@staticmethod @staticmethod
@ -200,7 +204,10 @@ class Query:
@staticmethod @staticmethod
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query: def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
pass pass
@staticmethod
def boolean_query(
    subqueries: Sequence[tuple[Occur, Query]],
) -> Query:
    """Combine `(Occur, Query)` pairs into a single boolean query.

    Args:
        subqueries: Sequence of `(Occur, Query)` pairs, one per clause.

    Returns:
        The combined query.
    """
    pass
class Order(Enum): class Order(Enum):
Asc = 1 Asc = 1

View File

@ -8,7 +8,7 @@ import pytest
import tantivy import tantivy
from conftest import schema, schema_numeric_fields from conftest import schema, schema_numeric_fields
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur
class TestClass(object): class TestClass(object):
@ -819,4 +819,62 @@ class TestQuery(object):
titles.update(index.searcher().doc(doc_address)["title"]) titles.update(index.searcher().doc(doc_address)["title"])
assert titles == {"Frankenstein", "The Modern Prometheus"} assert titles == {"Frankenstein", "The Modern Prometheus"}
def test_boolean_query(self, ram_index):
    """Check `Query.boolean_query` for each `Occur` kind and for malformed clauses."""
    index = ram_index
    ice_query = Query.fuzzy_term_query(index.schema, "title", "ice")
    mna_query = Query.fuzzy_term_query(index.schema, "title", "mna")

    # Both clauses required: no document matches both fuzzy terms.
    both_required = Query.boolean_query(
        [(Occur.Must, ice_query), (Occur.Must, mna_query)]
    )
    result = index.searcher().search(both_required, 10)
    assert len(result.hits) == 0

    # Either clause may match: two documents, one per fuzzy term.
    either = Query.boolean_query(
        [(Occur.Should, ice_query), (Occur.Should, mna_query)]
    )
    result = index.searcher().search(either, 10)
    assert len(result.hits) == 2
    titles = set()
    for _, doc_address in result.hits:
        titles.update(index.searcher().doc(doc_address)["title"])
    assert {"The Old Man and the Sea", "Of Mice and Men"}.issubset(titles)

    # The same subquery as MustNot and Must: the exclusion wins.
    contradictory = Query.boolean_query(
        [(Occur.MustNot, ice_query), (Occur.Must, ice_query)]
    )
    result = index.searcher().search(contradictory, 10)
    assert len(result.hits) == 0

    # A tuple of clauses is accepted just like a list of clauses.
    either_from_tuple = Query.boolean_query(
        ((Occur.Should, ice_query), (Occur.Should, mna_query))
    )
    result = index.searcher().search(either_from_tuple, 10)
    assert len(result.hits) == 2

    # A clause with three elements is rejected.
    with pytest.raises(
        ValueError,
        match="expected tuple of length 2, but got tuple of length 3",
    ):
        Query.boolean_query([(Occur.Must, Occur.Must, ice_query)])

    # A clause with its elements in the wrong order is rejected.
    with pytest.raises(
        TypeError,
        match=r"'Query' object cannot be converted to 'Occur'",
    ):
        Query.boolean_query([(ice_query, Occur.Must)])