Expose the method of boolean query (#243)
parent
9fa82ef29c
commit
1d61b96ffc
|
@ -15,7 +15,7 @@ mod snippet;
|
||||||
use document::Document;
|
use document::Document;
|
||||||
use facet::Facet;
|
use facet::Facet;
|
||||||
use index::Index;
|
use index::Index;
|
||||||
use query::Query;
|
use query::{Occur, Query};
|
||||||
use schema::Schema;
|
use schema::Schema;
|
||||||
use schemabuilder::SchemaBuilder;
|
use schemabuilder::SchemaBuilder;
|
||||||
use searcher::{DocAddress, Order, SearchResult, Searcher};
|
use searcher::{DocAddress, Order, SearchResult, Searcher};
|
||||||
|
@ -87,6 +87,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<Query>()?;
|
m.add_class::<Query>()?;
|
||||||
m.add_class::<Snippet>()?;
|
m.add_class::<Snippet>()?;
|
||||||
m.add_class::<SnippetGenerator>()?;
|
m.add_class::<SnippetGenerator>()?;
|
||||||
|
m.add_class::<Occur>()?;
|
||||||
|
|
||||||
m.add_wrapped(wrap_pymodule!(query_parser_error))?;
|
m.add_wrapped(wrap_pymodule!(query_parser_error))?;
|
||||||
|
|
||||||
|
|
61
src/query.rs
61
src/query.rs
|
@ -1,13 +1,55 @@
|
||||||
use crate::{make_term, Schema};
|
use crate::{make_term, Schema};
|
||||||
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString};
|
use pyo3::{exceptions, prelude::*, types::PyAny, types::PyString, types::PyTuple};
|
||||||
use tantivy as tv;
|
use tantivy as tv;
|
||||||
|
|
||||||
|
/// A single boolean-query clause: an `Occur` paired with the `Query` it
/// applies to.
///
/// The accompanying `FromPyObject` implementation builds this from a Python
/// tuple whose first item is the `Occur` and whose second item is the `Query`.
// NOTE(review): `Query::boolean_query`, as visible in this file, takes
// `Vec<(Occur, Query)>` rather than this type — confirm it is still needed.
struct OccurQueryPair(Occur, Query);
|
||||||
|
|
||||||
|
impl <'source> FromPyObject<'source> for OccurQueryPair {
|
||||||
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
|
let tuple = ob.downcast::<PyTuple>()?;
|
||||||
|
let occur = tuple.get_item(0)?.extract()?;
|
||||||
|
let query = tuple.get_item(1)?.extract()?;
|
||||||
|
|
||||||
|
Ok(OccurQueryPair(occur, query))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Tantivy's Occur
|
||||||
|
#[pyclass(frozen, module = "tantivy.tantivy")]
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum Occur {
|
||||||
|
Must,
|
||||||
|
Should,
|
||||||
|
MustNot,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Occur> for tv::query::Occur {
|
||||||
|
fn from(occur: Occur) -> tv::query::Occur {
|
||||||
|
match occur {
|
||||||
|
Occur::Must => tv::query::Occur::Must,
|
||||||
|
Occur::Should => tv::query::Occur::Should,
|
||||||
|
Occur::MustNot => tv::query::Occur::MustNot,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Tantivy's Query
///
/// Holds a boxed `tantivy::query::Query` trait object so that any concrete
/// tantivy query can be exposed through one Python class.
#[pyclass(frozen, module = "tantivy.tantivy")]
pub(crate) struct Query {
    /// The wrapped tantivy query.
    pub(crate) inner: Box<dyn tv::query::Query>,
}
|
||||||
|
|
||||||
|
impl Clone for Query {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Query {
|
||||||
|
inner: self.inner.box_clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Query {
|
impl Query {
|
||||||
pub(crate) fn get(&self) -> &dyn tv::query::Query {
|
pub(crate) fn get(&self) -> &dyn tv::query::Query {
|
||||||
&self.inner
|
&self.inner
|
||||||
|
@ -91,4 +133,21 @@ impl Query {
|
||||||
inner: Box::new(inner),
|
inner: Box::new(inner),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[staticmethod]
|
||||||
|
#[pyo3(signature = (subqueries))]
|
||||||
|
pub(crate) fn boolean_query(
|
||||||
|
subqueries: Vec<(Occur, Query)>
|
||||||
|
) -> PyResult<Query> {
|
||||||
|
let dyn_subqueries = subqueries
|
||||||
|
.into_iter()
|
||||||
|
.map(|(occur, query)| (occur.into(), query.inner.box_clone()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let inner = tv::query::BooleanQuery::from(dyn_subqueries);
|
||||||
|
|
||||||
|
Ok(Query {
|
||||||
|
inner: Box::new(inner),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, Sequence
|
||||||
|
|
||||||
|
|
||||||
class Schema:
|
class Schema:
|
||||||
|
@ -187,6 +187,10 @@ class Document:
|
||||||
def get_all(self, field_name: str) -> list[Any]:
|
def get_all(self, field_name: str) -> list[Any]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class Occur(Enum):
    """Tag attached to each subquery passed to `Query.boolean_query`."""

    # Member names match the variants of the Rust `Occur` enum.
    # NOTE(review): the integer values are stub placeholders — confirm whether
    # they are meant to match what the compiled extension exposes.
    Must = 1
    Should = 2
    MustNot = 3
|
||||||
|
|
||||||
class Query:
|
class Query:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -201,6 +205,9 @@ class Query:
|
||||||
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
|
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query:
    """Combine several queries into a single boolean query.

    Args:
        subqueries: `(Occur, Query)` pairs; each pair names a subquery and
            the `Occur` tag that applies to it.

    Returns:
        The combined query.
    """
    pass
|
||||||
|
|
||||||
class Order(Enum):
|
class Order(Enum):
|
||||||
Asc = 1
|
Asc = 1
|
||||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
||||||
|
|
||||||
import tantivy
|
import tantivy
|
||||||
from conftest import schema, schema_numeric_fields
|
from conftest import schema, schema_numeric_fields
|
||||||
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query
|
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur
|
||||||
|
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
@ -819,4 +819,62 @@ class TestQuery(object):
|
||||||
titles.update(index.searcher().doc(doc_address)["title"])
|
titles.update(index.searcher().doc(doc_address)["title"])
|
||||||
assert titles == {"Frankenstein", "The Modern Prometheus"}
|
assert titles == {"Frankenstein", "The Modern Prometheus"}
|
||||||
|
|
||||||
|
def test_boolean_query(self, ram_index):
    """Check `Query.boolean_query` for each `Occur` kind and for malformed clauses."""
    index = ram_index
    ice_query = Query.fuzzy_term_query(index.schema, "title", "ice")
    mna_query = Query.fuzzy_term_query(index.schema, "title", "mna")

    # Requiring both fuzzy terms at once: no single document satisfies both.
    both_required = Query.boolean_query(
        [(Occur.Must, ice_query), (Occur.Must, mna_query)]
    )
    outcome = index.searcher().search(both_required, 10)
    assert len(outcome.hits) == 0

    # Accepting either term: one document is found per subquery.
    either_list = Query.boolean_query(
        [(Occur.Should, ice_query), (Occur.Should, mna_query)]
    )
    outcome = index.searcher().search(either_list, 10)
    assert len(outcome.hits) == 2

    found_titles = set()
    for _, address in outcome.hits:
        found_titles.update(index.searcher().doc(address)["title"])
    assert {"The Old Man and the Sea", "Of Mice and Men"} <= found_titles

    # Excluding and requiring the same subquery: the exclusion wins.
    contradictory = Query.boolean_query(
        [(Occur.MustNot, ice_query), (Occur.Must, ice_query)]
    )
    outcome = index.searcher().search(contradictory, 10)
    assert len(outcome.hits) == 0

    # The clauses may be passed as a tuple instead of a list.
    either_tuple = Query.boolean_query(
        ((Occur.Should, ice_query), (Occur.Should, mna_query))
    )
    outcome = index.searcher().search(either_tuple, 10)
    assert len(outcome.hits) == 2

    # A clause with three items is rejected.
    with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
        Query.boolean_query([(Occur.Must, Occur.Must, ice_query)])

    # A clause with its two items swapped is rejected.
    with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
        Query.boolean_query([(ice_query, Occur.Must)])
|
Loading…
Reference in New Issue