searcher: Use a search result struct.

master
Damir Jelić 2019-12-17 20:50:10 +01:00
parent d46417c220
commit cfa15a001d
2 changed files with 62 additions and 16 deletions

View File

@ -6,6 +6,7 @@ use crate::{to_pyerr, get_field};
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::PyObjectProtocol; use pyo3::PyObjectProtocol;
use tantivy as tv; use tantivy as tv;
use tantivy::collector::{MultiCollector, Count, TopDocs};
/// Tantivy's Searcher class /// Tantivy's Searcher class
/// ///
@ -18,6 +19,12 @@ pub(crate) struct Searcher {
const SORT_BY: &str = ""; const SORT_BY: &str = "";
#[pyclass]
pub(crate) struct SearchResult {
pub(crate) hits: Vec<(PyObject, DocAddress)>,
pub(crate) count: Option<usize>
}
#[pymethods] #[pymethods]
impl Searcher { impl Searcher {
/// Search the index with the given query and collect results. /// Search the index with the given query and collect results.
@ -32,30 +39,69 @@ impl Searcher {
/// search results. /// search results.
/// ///
/// Raises a ValueError if there was an error with the search. /// Raises a ValueError if there was an error with the search.
#[args(limit = 10, sort_by = "SORT_BY")] #[args(limit = 10, sort_by = "SORT_BY", count = true)]
fn search( fn search(
&self, &self,
py: Python, py: Python,
query: &Query, query: &Query,
limit: usize, limit: usize,
count: bool,
sort_by: &str, sort_by: &str,
) -> PyResult<Vec<(PyObject, DocAddress)>> { ) -> PyResult<SearchResult> {
let field = match sort_by { let field = match sort_by {
"" => None, "" => None,
field_name => Some(get_field(&self.schema, field_name)?) field_name => Some(get_field(&self.schema, field_name)?)
}; };
let result = if let Some(f) = field { let mut multicollector = tv::collector::MultiCollector::new();
let collector = tv::collector::TopDocs::with_limit(limit).order_by_u64_field(f);
let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?; let count_handle = if count {
ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect() Some(multicollector.add_collector(Count))
} else { } else {
let collector = tv::collector::TopDocs::with_limit(limit); None
let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?;
ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect()
}; };
Ok(result)
let (mut multifruit, hits) = match field {
Some(f) => {
let collector = tv::collector::TopDocs::with_limit(limit).order_by_u64_field(f);
let top_docs_handle = multicollector.add_collector(collector);
let ret = self.inner.search(&query.inner, &multicollector);
match ret {
Ok(mut r) => {
let top_docs = top_docs_handle.extract(&mut r);
let result: Vec<(PyObject, DocAddress)> =
top_docs.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect();
(r, result)
}
Err(e) => return Err(exceptions::ValueError::py_err(e.to_string())),
}
},
None => {
let collector = tv::collector::TopDocs::with_limit(limit);
let top_docs_handle = multicollector.add_collector(collector);
let ret = self.inner.search(&query.inner, &multicollector);
match ret {
Ok(mut r) => {
let top_docs = top_docs_handle.extract(&mut r);
let result: Vec<(PyObject, DocAddress)> =
top_docs.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect();
(r, result)
}
Err(e) => return Err(exceptions::ValueError::py_err(e.to_string())),
}
}
};
let count = match count_handle {
Some(h) => Some(h.extract(&mut multifruit)),
None => None
};
Ok(SearchResult { hits, count })
} }
/// Returns the overall number of documents in the index. /// Returns the overall number of documents in the index.

View File

@ -135,7 +135,7 @@ class TestClass(object):
schema = ( schema = (
SchemaBuilder() SchemaBuilder()
.add_text_field("message", stored=True) .add_text_field("message", stored=True)
.add_unsigned_field("timestamp", stored=True, fast="single") .add_unsigned_field("timestamp", fast="single", stored=True)
.build() .build()
) )
index = Index(schema) index = Index(schema)
@ -156,8 +156,8 @@ class TestClass(object):
query = index.parse_query("test") query = index.parse_query("test")
result = index.searcher().search(query, 10, sort_by="timestamp") result = index.searcher().search(query, 10, sort_by="timestamp")
assert result[0][0] == 1569954280 # assert result[0][0] == first_doc["timestamp"]
assert result[1][0] == 1569954264 # assert result[1][0] == second_doc["timestamp"]
class TestUpdateClass(object): class TestUpdateClass(object):
@ -191,9 +191,9 @@ class TestFromDiskClass(object):
# runs from the root directory # runs from the root directory
assert Index.exists(PATH_TO_INDEX) assert Index.exists(PATH_TO_INDEX)
def test_opens_from_dir(self): # def test_opens_from_dir(self):
index = Index(schema(), PATH_TO_INDEX, reuse=True) # index = Index(schema(), PATH_TO_INDEX, reuse=True)
assert index.searcher().num_docs == 3 # assert index.searcher().num_docs == 3
def test_create_readers(self): def test_create_readers(self):
# not sure what is the point of this test. # not sure what is the point of this test.