Merge remote-tracking branch 'origin/search_api_simplification'
commit
46be799248
11
src/index.rs
11
src/index.rs
|
@ -8,7 +8,7 @@ use crate::document::{extract_value, Document};
|
||||||
use crate::query::Query;
|
use crate::query::Query;
|
||||||
use crate::schema::Schema;
|
use crate::schema::Schema;
|
||||||
use crate::searcher::Searcher;
|
use crate::searcher::Searcher;
|
||||||
use crate::to_pyerr;
|
use crate::{to_pyerr, get_field};
|
||||||
use tantivy as tv;
|
use tantivy as tv;
|
||||||
use tantivy::directory::MmapDirectory;
|
use tantivy::directory::MmapDirectory;
|
||||||
use tantivy::schema::{NamedFieldDocument, Term, Value};
|
use tantivy::schema::{NamedFieldDocument, Term, Value};
|
||||||
|
@ -111,13 +111,7 @@ impl IndexWriter {
|
||||||
field_name: &str,
|
field_name: &str,
|
||||||
field_value: &PyAny,
|
field_value: &PyAny,
|
||||||
) -> PyResult<u64> {
|
) -> PyResult<u64> {
|
||||||
let field = self.schema.get_field(field_name).ok_or_else(|| {
|
let field = get_field(&self.schema, field_name)?;
|
||||||
exceptions::ValueError::py_err(format!(
|
|
||||||
"Field `{}` is not defined in the schema.",
|
|
||||||
field_name
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let value = extract_value(field_value)?;
|
let value = extract_value(field_value)?;
|
||||||
let term = match value {
|
let term = match value {
|
||||||
Value::Str(text) => Term::from_field_text(field, &text),
|
Value::Str(text) => Term::from_field_text(field, &text),
|
||||||
|
@ -274,6 +268,7 @@ impl Index {
|
||||||
fn searcher(&self) -> Searcher {
|
fn searcher(&self) -> Searcher {
|
||||||
Searcher {
|
Searcher {
|
||||||
inner: self.reader.searcher(),
|
inner: self.reader.searcher(),
|
||||||
|
schema: self.index.schema(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
15
src/lib.rs
15
src/lib.rs
|
@ -1,5 +1,6 @@
|
||||||
use pyo3::exceptions;
|
use pyo3::exceptions;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
use tantivy as tv;
|
||||||
|
|
||||||
mod document;
|
mod document;
|
||||||
mod facet;
|
mod facet;
|
||||||
|
@ -14,7 +15,7 @@ use facet::Facet;
|
||||||
use index::Index;
|
use index::Index;
|
||||||
use schema::Schema;
|
use schema::Schema;
|
||||||
use schemabuilder::SchemaBuilder;
|
use schemabuilder::SchemaBuilder;
|
||||||
use searcher::{DocAddress, Searcher, TopDocs};
|
use searcher::{DocAddress, Searcher};
|
||||||
|
|
||||||
/// Python bindings for the search engine library Tantivy.
|
/// Python bindings for the search engine library Tantivy.
|
||||||
///
|
///
|
||||||
|
@ -76,7 +77,6 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<Document>()?;
|
m.add_class::<Document>()?;
|
||||||
m.add_class::<Index>()?;
|
m.add_class::<Index>()?;
|
||||||
m.add_class::<DocAddress>()?;
|
m.add_class::<DocAddress>()?;
|
||||||
m.add_class::<TopDocs>()?;
|
|
||||||
m.add_class::<Facet>()?;
|
m.add_class::<Facet>()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -84,3 +84,14 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||||
pub(crate) fn to_pyerr<E: ToString>(err: E) -> PyErr {
|
pub(crate) fn to_pyerr<E: ToString>(err: E) -> PyErr {
|
||||||
exceptions::ValueError::py_err(err.to_string())
|
exceptions::ValueError::py_err(err.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_field(schema: &tv::schema::Schema, field_name: &str) -> PyResult<tv::schema::Field> {
|
||||||
|
let field = schema.get_field(field_name).ok_or_else(|| {
|
||||||
|
exceptions::ValueError::py_err(format!(
|
||||||
|
"Field `{}` is not defined in the schema.",
|
||||||
|
field_name
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(field)
|
||||||
|
}
|
||||||
|
|
111
src/searcher.rs
111
src/searcher.rs
|
@ -2,10 +2,12 @@
|
||||||
|
|
||||||
use crate::document::Document;
|
use crate::document::Document;
|
||||||
use crate::query::Query;
|
use crate::query::Query;
|
||||||
use crate::to_pyerr;
|
use crate::{to_pyerr};
|
||||||
|
use pyo3::exceptions::ValueError;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::{exceptions, PyObjectProtocol};
|
use pyo3::PyObjectProtocol;
|
||||||
use tantivy as tv;
|
use tantivy as tv;
|
||||||
|
use tantivy::collector::{Count, MultiCollector, TopDocs};
|
||||||
|
|
||||||
/// Tantivy's Searcher class
|
/// Tantivy's Searcher class
|
||||||
///
|
///
|
||||||
|
@ -13,6 +15,32 @@ use tantivy as tv;
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub(crate) struct Searcher {
|
pub(crate) struct Searcher {
|
||||||
pub(crate) inner: tv::LeasedItem<tv::Searcher>,
|
pub(crate) inner: tv::LeasedItem<tv::Searcher>,
|
||||||
|
pub(crate) schema: tv::schema::Schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
/// Object holding a results successful search.
|
||||||
|
pub(crate) struct SearchResult {
|
||||||
|
hits: Vec<(PyObject, DocAddress)>,
|
||||||
|
#[pyo3(get)]
|
||||||
|
/// How many documents matched the query. Only available if `count` was set
|
||||||
|
/// to true during the search.
|
||||||
|
count: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl SearchResult {
|
||||||
|
#[getter]
|
||||||
|
/// The list of tuples that contains the scores and DocAddress of the
|
||||||
|
/// search results.
|
||||||
|
fn hits(&self, py: Python) -> PyResult<Vec<(PyObject, DocAddress)>> {
|
||||||
|
let ret: Vec<(PyObject, DocAddress)> = self
|
||||||
|
.hits
|
||||||
|
.iter()
|
||||||
|
.map(|(obj, address)| (obj.clone_ref(py), address.clone()))
|
||||||
|
.collect();
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
|
@ -21,28 +49,54 @@ impl Searcher {
|
||||||
///
|
///
|
||||||
/// Args:
|
/// Args:
|
||||||
/// query (Query): The query that will be used for the search.
|
/// query (Query): The query that will be used for the search.
|
||||||
/// collector (Collector): A collector that determines how the search
|
/// limit (int, optional): The maximum number of search results to
|
||||||
/// results will be collected. Only the TopDocs collector is
|
/// return. Defaults to 10.
|
||||||
/// supported for now.
|
/// count (bool, optional): Should the number of documents that match
|
||||||
|
/// the query be returned as well. Defaults to true.
|
||||||
///
|
///
|
||||||
/// Returns a list of tuples that contains the scores and DocAddress of the
|
/// Returns `SearchResult` object.
|
||||||
/// search results.
|
|
||||||
///
|
///
|
||||||
/// Raises a ValueError if there was an error with the search.
|
/// Raises a ValueError if there was an error with the search.
|
||||||
|
#[args(limit = 10, count = true)]
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
|
py: Python,
|
||||||
query: &Query,
|
query: &Query,
|
||||||
collector: &mut TopDocs,
|
limit: usize,
|
||||||
) -> PyResult<Vec<(f32, DocAddress)>> {
|
count: bool,
|
||||||
let ret = self.inner.search(&query.inner, &collector.inner);
|
) -> PyResult<SearchResult> {
|
||||||
match ret {
|
let mut multicollector = MultiCollector::new();
|
||||||
Ok(r) => {
|
|
||||||
let result: Vec<(f32, DocAddress)> =
|
let count_handle = if count {
|
||||||
r.iter().map(|(f, d)| (*f, DocAddress::from(d))).collect();
|
Some(multicollector.add_collector(Count))
|
||||||
Ok(result)
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let (mut multifruit, hits) = {
|
||||||
|
let collector = TopDocs::with_limit(limit);
|
||||||
|
let top_docs_handle = multicollector.add_collector(collector);
|
||||||
|
let ret = self.inner.search(&query.inner, &multicollector);
|
||||||
|
|
||||||
|
match ret {
|
||||||
|
Ok(mut r) => {
|
||||||
|
let top_docs = top_docs_handle.extract(&mut r);
|
||||||
|
let result: Vec<(PyObject, DocAddress)> = top_docs
|
||||||
|
.iter()
|
||||||
|
.map(|(f, d)| ((*f).into_py(py), DocAddress::from(d)))
|
||||||
|
.collect();
|
||||||
|
(r, result)
|
||||||
|
}
|
||||||
|
Err(e) => return Err(ValueError::py_err(e.to_string())),
|
||||||
}
|
}
|
||||||
Err(e) => Err(exceptions::ValueError::py_err(e.to_string())),
|
};
|
||||||
}
|
|
||||||
|
let count = match count_handle {
|
||||||
|
Some(h) => Some(h.extract(&mut multifruit)),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(SearchResult { hits, count })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the overall number of documents in the index.
|
/// Returns the overall number of documents in the index.
|
||||||
|
@ -74,6 +128,7 @@ impl Searcher {
|
||||||
/// The id used for the segment is actually an ordinal in the list of segment
|
/// The id used for the segment is actually an ordinal in the list of segment
|
||||||
/// hold by a Searcher.
|
/// hold by a Searcher.
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
|
#[derive(Clone)]
|
||||||
pub(crate) struct DocAddress {
|
pub(crate) struct DocAddress {
|
||||||
pub(crate) segment_ord: tv::SegmentLocalId,
|
pub(crate) segment_ord: tv::SegmentLocalId,
|
||||||
pub(crate) doc: tv::DocId,
|
pub(crate) doc: tv::DocId,
|
||||||
|
@ -110,28 +165,6 @@ impl Into<tv::DocAddress> for &DocAddress {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The Top Score Collector keeps track of the K documents sorted by their
|
|
||||||
/// score.
|
|
||||||
///
|
|
||||||
/// Args:
|
|
||||||
/// limit (int, optional): The number of documents that the top scorer will
|
|
||||||
/// retrieve. Must be a positive integer larger than 0. Defaults to 10.
|
|
||||||
#[pyclass]
|
|
||||||
pub(crate) struct TopDocs {
|
|
||||||
inner: tv::collector::TopDocs,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymethods]
|
|
||||||
impl TopDocs {
|
|
||||||
#[new]
|
|
||||||
#[args(limit = 10)]
|
|
||||||
fn new(obj: &PyRawObject, limit: usize) -> PyResult<()> {
|
|
||||||
let top = tv::collector::TopDocs::with_limit(limit);
|
|
||||||
obj.init(TopDocs { inner: top });
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyproto]
|
#[pyproto]
|
||||||
impl PyObjectProtocol for Searcher {
|
impl PyObjectProtocol for Searcher {
|
||||||
fn __repr__(&self) -> PyResult<String> {
|
fn __repr__(&self) -> PyResult<String> {
|
||||||
|
|
|
@ -76,30 +76,24 @@ class TestClass(object):
|
||||||
_, index = dir_index
|
_, index = dir_index
|
||||||
query = index.parse_query("sea whale", ["title", "body"])
|
query = index.parse_query("sea whale", ["title", "body"])
|
||||||
|
|
||||||
top_docs = tantivy.TopDocs(10)
|
result = index.searcher().search(query, 10)
|
||||||
|
assert len(result.hits) == 1
|
||||||
result = index.searcher().search(query, top_docs)
|
|
||||||
assert len(result) == 1
|
|
||||||
|
|
||||||
def test_simple_search_after_reuse(self, dir_index):
|
def test_simple_search_after_reuse(self, dir_index):
|
||||||
index_dir, _ = dir_index
|
index_dir, _ = dir_index
|
||||||
index = Index(schema(), str(index_dir))
|
index = Index(schema(), str(index_dir))
|
||||||
query = index.parse_query("sea whale", ["title", "body"])
|
query = index.parse_query("sea whale", ["title", "body"])
|
||||||
|
|
||||||
top_docs = tantivy.TopDocs(10)
|
result = index.searcher().search(query, 10)
|
||||||
|
assert len(result.hits) == 1
|
||||||
result = index.searcher().search(query, top_docs)
|
|
||||||
assert len(result) == 1
|
|
||||||
|
|
||||||
def test_simple_search_in_ram(self, ram_index):
|
def test_simple_search_in_ram(self, ram_index):
|
||||||
index = ram_index
|
index = ram_index
|
||||||
query = index.parse_query("sea whale", ["title", "body"])
|
query = index.parse_query("sea whale", ["title", "body"])
|
||||||
|
|
||||||
top_docs = tantivy.TopDocs(10)
|
result = index.searcher().search(query, 10)
|
||||||
|
assert len(result.hits) == 1
|
||||||
result = index.searcher().search(query, top_docs)
|
_, doc_address = result.hits[0]
|
||||||
assert len(result) == 1
|
|
||||||
_, doc_address = result[0]
|
|
||||||
searched_doc = index.searcher().doc(doc_address)
|
searched_doc = index.searcher().doc(doc_address)
|
||||||
assert searched_doc["title"] == ["The Old Man and the Sea"]
|
assert searched_doc["title"] == ["The Old Man and the Sea"]
|
||||||
|
|
||||||
|
@ -107,17 +101,16 @@ class TestClass(object):
|
||||||
index = ram_index
|
index = ram_index
|
||||||
query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"])
|
query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"])
|
||||||
# look for an intersection of documents
|
# look for an intersection of documents
|
||||||
top_docs = tantivy.TopDocs(10)
|
|
||||||
searcher = index.searcher()
|
searcher = index.searcher()
|
||||||
result = searcher.search(query, top_docs)
|
result = searcher.search(query, 10)
|
||||||
|
|
||||||
# summer isn't present
|
# summer isn't present
|
||||||
assert len(result) == 0
|
assert len(result.hits) == 0
|
||||||
|
|
||||||
query = index.parse_query("title:men AND body:winter", ["title", "body"])
|
query = index.parse_query("title:men AND body:winter", ["title", "body"])
|
||||||
result = searcher.search(query, top_docs)
|
result = searcher.search(query)
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result.hits) == 1
|
||||||
|
|
||||||
def test_and_query_parser_default_fields(self, ram_index):
|
def test_and_query_parser_default_fields(self, ram_index):
|
||||||
query = ram_index.parse_query("winter", default_field_names=["title"])
|
query = ram_index.parse_query("winter", default_field_names=["title"])
|
||||||
|
@ -138,13 +131,11 @@ class TestClass(object):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
index.parse_query("bod:men", ["title", "body"])
|
index.parse_query("bod:men", ["title", "body"])
|
||||||
|
|
||||||
|
|
||||||
class TestUpdateClass(object):
|
class TestUpdateClass(object):
|
||||||
def test_delete_update(self, ram_index):
|
def test_delete_update(self, ram_index):
|
||||||
query = ram_index.parse_query("Frankenstein", ["title"])
|
query = ram_index.parse_query("Frankenstein", ["title"])
|
||||||
top_docs = tantivy.TopDocs(10)
|
result = ram_index.searcher().search(query, 10)
|
||||||
result = ram_index.searcher().search(query, top_docs)
|
assert len(result.hits) == 1
|
||||||
assert len(result) == 1
|
|
||||||
|
|
||||||
writer = ram_index.writer()
|
writer = ram_index.writer()
|
||||||
|
|
||||||
|
@ -158,8 +149,8 @@ class TestUpdateClass(object):
|
||||||
writer.commit()
|
writer.commit()
|
||||||
ram_index.reload()
|
ram_index.reload()
|
||||||
|
|
||||||
result = ram_index.searcher().search(query, top_docs)
|
result = ram_index.searcher().search(query)
|
||||||
assert len(result) == 0
|
assert len(result.hits) == 0
|
||||||
|
|
||||||
|
|
||||||
PATH_TO_INDEX = "tests/test_index/"
|
PATH_TO_INDEX = "tests/test_index/"
|
||||||
|
|
Loading…
Reference in New Issue