Expose the method of boolean query (#243)
parent
9fa82ef29c
commit
1d61b96ffc
|
@ -15,7 +15,7 @@ mod snippet;
|
|||
use document::Document;
|
||||
use facet::Facet;
|
||||
use index::Index;
|
||||
use query::Query;
|
||||
use query::{Occur, Query};
|
||||
use schema::Schema;
|
||||
use schemabuilder::SchemaBuilder;
|
||||
use searcher::{DocAddress, Order, SearchResult, Searcher};
|
||||
|
@ -87,6 +87,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
|
|||
m.add_class::<Query>()?;
|
||||
m.add_class::<Snippet>()?;
|
||||
m.add_class::<SnippetGenerator>()?;
|
||||
m.add_class::<Occur>()?;
|
||||
|
||||
m.add_wrapped(wrap_pymodule!(query_parser_error))?;
|
||||
|
||||
|
|
61
src/query.rs
61
src/query.rs
|
@ -1,13 +1,55 @@
|
|||
use crate::{make_term, Schema};
|
||||
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString};
|
||||
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple};
|
||||
use tantivy as tv;
|
||||
|
||||
/// Custom Tuple struct to represent a pair of Occur and Query
|
||||
/// for the BooleanQuery
|
||||
struct OccurQueryPair(Occur, Query);
|
||||
|
||||
impl <'source> FromPyObject<'source> for OccurQueryPair {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let tuple = ob.downcast::<PyTuple>()?;
|
||||
let occur = tuple.get_item(0)?.extract()?;
|
||||
let query = tuple.get_item(1)?.extract()?;
|
||||
|
||||
Ok(OccurQueryPair(occur, query))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Tantivy's Occur
|
||||
#[pyclass(frozen, module = "tantivy.tantivy")]
|
||||
#[derive(Clone)]
|
||||
pub enum Occur {
|
||||
Must,
|
||||
Should,
|
||||
MustNot,
|
||||
}
|
||||
|
||||
impl From<Occur> for tv::query::Occur {
|
||||
fn from(occur: Occur) -> tv::query::Occur {
|
||||
match occur {
|
||||
Occur::Must => tv::query::Occur::Must,
|
||||
Occur::Should => tv::query::Occur::Should,
|
||||
Occur::MustNot => tv::query::Occur::MustNot,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tantivy's Query
|
||||
#[pyclass(frozen, module = "tantivy.tantivy")]
|
||||
pub(crate) struct Query {
|
||||
pub(crate) inner: Box<dyn tv::query::Query>,
|
||||
}
|
||||
|
||||
impl Clone for Query {
|
||||
fn clone(&self) -> Self {
|
||||
Query {
|
||||
inner: self.inner.box_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub(crate) fn get(&self) -> &dyn tv::query::Query {
|
||||
&self.inner
|
||||
|
@ -91,4 +133,21 @@ impl Query {
|
|||
inner: Box::new(inner),
|
||||
})
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
#[pyo3(signature = (subqueries))]
|
||||
pub(crate) fn boolean_query(
|
||||
subqueries: Vec<(Occur, Query)>
|
||||
) -> PyResult<Query> {
|
||||
let dyn_subqueries = subqueries
|
||||
.into_iter()
|
||||
.map(|(occur, query)| (occur.into(), query.inner.box_clone()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner = tv::query::BooleanQuery::from(dyn_subqueries);
|
||||
|
||||
Ok(Query {
|
||||
inner: Box::new(inner),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
|
||||
class Schema:
|
||||
|
@ -187,6 +187,10 @@ class Document:
|
|||
def get_all(self, field_name: str) -> list[Any]:
    # NOTE(review): stub only -- presumably returns every value stored under
    # `field_name`; confirm against the native implementation.
    pass
|
||||
|
||||
class Occur(Enum):
    """Type stub for the `Occur` values accepted by `Query.boolean_query`."""

    # NOTE(review): these integers only tell the members apart for type
    # checkers -- confirm whether they match the native enum's integer values.
    Must = 1
    Should = 2
    MustNot = 3
|
||||
|
||||
class Query:
|
||||
@staticmethod
|
||||
|
@ -201,6 +205,9 @@ class Query:
|
|||
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
|
||||
pass
|
||||
|
||||
@staticmethod
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
    """Build a query from `(Occur, Query)` pairs.

    Args:
        subqueries: Sequence of `(occur, query)` tuples, one per subquery.

    Returns:
        A `Query` combining the given subqueries.
    """
    pass
|
||||
|
||||
class Order(Enum):
|
||||
Asc = 1
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
|
||||
import tantivy
|
||||
from conftest import schema, schema_numeric_fields
|
||||
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query
|
||||
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur
|
||||
|
||||
|
||||
class TestClass(object):
|
||||
|
@ -819,4 +819,62 @@ class TestQuery(object):
|
|||
titles.update(index.searcher().doc(doc_address)["title"])
|
||||
assert titles == {"Frankenstein", "The Modern Prometheus"}
|
||||
|
||||
def test_boolean_query(self, ram_index):
    """Check Must / Should / MustNot combinations and malformed clause tuples."""
    index = ram_index
    ice_query = Query.fuzzy_term_query(index.schema, "title", "ice")
    mna_query = Query.fuzzy_term_query(index.schema, "title", "mna")

    # Both subqueries required at once: no document should match both.
    both_required = Query.boolean_query(
        [(Occur.Must, ice_query), (Occur.Must, mna_query)]
    )
    result = index.searcher().search(both_required, 10)
    assert len(result.hits) == 0

    # Either subquery suffices: two documents should match, one per subquery.
    either = Query.boolean_query(
        [(Occur.Should, ice_query), (Occur.Should, mna_query)]
    )
    result = index.searcher().search(either, 10)
    assert len(result.hits) == 2

    titles = {
        title
        for _, doc_address in result.hits
        for title in index.searcher().doc(doc_address)["title"]
    }
    assert {"The Old Man and the Sea", "Of Mice and Men"} <= titles

    # MustNot should take precedence over Must on the same subquery.
    contradictory = Query.boolean_query(
        [(Occur.MustNot, ice_query), (Occur.Must, ice_query)]
    )
    result = index.searcher().search(contradictory, 10)
    assert len(result.hits) == 0

    # A tuple of clauses should be accepted just like a list of clauses.
    either_from_tuple = Query.boolean_query(
        ((Occur.Should, ice_query), (Occur.Should, mna_query))
    )
    result = index.searcher().search(either_from_tuple, 10)
    assert len(result.hits) == 2

    # Invalid input: a clause tuple with three elements.
    with pytest.raises(
        ValueError,
        match="expected tuple of length 2, but got tuple of length 3",
    ):
        Query.boolean_query([(Occur.Must, Occur.Must, ice_query)])

    # Invalid input: the clause tuple given in swapped order.
    with pytest.raises(
        TypeError,
        match=r"'Query' object cannot be converted to 'Occur'",
    ):
        Query.boolean_query([(ice_query, Occur.Must)])
|
Loading…
Reference in New Issue