diff --git a/src/searcher.rs b/src/searcher.rs index bf92dd0..243c2b4 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -2,11 +2,12 @@ use crate::document::Document; use crate::query::Query; -use crate::{to_pyerr, get_field}; +use crate::{to_pyerr}; +use pyo3::exceptions::ValueError; use pyo3::prelude::*; use pyo3::PyObjectProtocol; use tantivy as tv; -use tantivy::collector::{MultiCollector, Count, TopDocs}; +use tantivy::collector::{Count, MultiCollector, TopDocs}; /// Tantivy's Searcher class /// @@ -17,12 +18,29 @@ pub(crate) struct Searcher { pub(crate) schema: tv::schema::Schema, } -const SORT_BY: &str = ""; - #[pyclass] +/// Object holding a results successful search. pub(crate) struct SearchResult { - pub(crate) hits: Vec<(PyObject, DocAddress)>, - pub(crate) count: Option + hits: Vec<(PyObject, DocAddress)>, + #[pyo3(get)] + /// How many documents matched the query. Only available if `count` was set + /// to true during the search. + count: Option, +} + +#[pymethods] +impl SearchResult { + #[getter] + /// The list of tuples that contains the scores and DocAddress of the + /// search results. + fn hits(&self, py: Python) -> PyResult> { + let ret: Vec<(PyObject, DocAddress)> = self + .hits + .iter() + .map(|(obj, address)| (obj.clone_ref(py), address.clone())) + .collect(); + Ok(ret) + } } #[pymethods] @@ -31,29 +49,23 @@ impl Searcher { /// /// Args: /// query (Query): The query that will be used for the search. - /// collector (Collector): A collector that determines how the search - /// results will be collected. Only the TopDocs collector is - /// supported for now. + /// limit (int, optional): The maximum number of search results to + /// return. Defaults to 10. + /// count (bool, optional): Should the number of documents that match + /// the query be returned as well. Defaults to true. /// - /// Returns a list of tuples that contains the scores and DocAddress of the - /// search results. + /// Returns `SearchResult` object. /// /// Raises a ValueError if there was an error with the search. - #[args(limit = 10, sort_by = "SORT_BY", count = true)] + #[args(limit = 10, count = true)] fn search( &self, py: Python, query: &Query, limit: usize, count: bool, - sort_by: &str, ) -> PyResult { - let field = match sort_by { - "" => None, - field_name => Some(get_field(&self.schema, field_name)?) - }; - - let mut multicollector = tv::collector::MultiCollector::new(); + let mut multicollector = MultiCollector::new(); let count_handle = if count { Some(multicollector.add_collector(Count)) @@ -61,44 +73,27 @@ impl Searcher { None }; + let (mut multifruit, hits) = { + let collector = TopDocs::with_limit(limit); + let top_docs_handle = multicollector.add_collector(collector); + let ret = self.inner.search(&query.inner, &multicollector); - let (mut multifruit, hits) = match field { - Some(f) => { - let collector = tv::collector::TopDocs::with_limit(limit).order_by_u64_field(f); - let top_docs_handle = multicollector.add_collector(collector); - let ret = self.inner.search(&query.inner, &multicollector); - - match ret { - Ok(mut r) => { - let top_docs = top_docs_handle.extract(&mut r); - let result: Vec<(PyObject, DocAddress)> = - top_docs.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect(); - (r, result) - } - Err(e) => return Err(exceptions::ValueError::py_err(e.to_string())), - } - - }, - None => { - let collector = tv::collector::TopDocs::with_limit(limit); - let top_docs_handle = multicollector.add_collector(collector); - let ret = self.inner.search(&query.inner, &multicollector); - - match ret { - Ok(mut r) => { - let top_docs = top_docs_handle.extract(&mut r); - let result: Vec<(PyObject, DocAddress)> = - top_docs.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect(); - (r, result) - } - Err(e) => return Err(exceptions::ValueError::py_err(e.to_string())), + match ret { + Ok(mut r) => { + let top_docs = top_docs_handle.extract(&mut r); + let result: Vec<(PyObject, DocAddress)> = top_docs + .iter() + .map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))) + .collect(); + (r, result) } + Err(e) => return Err(ValueError::py_err(e.to_string())), } }; let count = match count_handle { Some(h) => Some(h.extract(&mut multifruit)), - None => None + None => None, }; Ok(SearchResult { hits, count }) @@ -133,6 +128,7 @@ impl Searcher { /// The id used for the segment is actually an ordinal in the list of segment /// hold by a Searcher. #[pyclass] +#[derive(Clone)] pub(crate) struct DocAddress { pub(crate) segment_ord: tv::SegmentLocalId, pub(crate) doc: tv::DocId, diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 3530298..91e4c8b 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -77,7 +77,7 @@ class TestClass(object): query = index.parse_query("sea whale", ["title", "body"]) result = index.searcher().search(query, 10) - assert len(result) == 1 + assert len(result.hits) == 1 def test_simple_search_after_reuse(self, dir_index): index_dir, _ = dir_index @@ -85,15 +85,15 @@ class TestClass(object): query = index.parse_query("sea whale", ["title", "body"]) result = index.searcher().search(query, 10) - assert len(result) == 1 + assert len(result.hits) == 1 def test_simple_search_in_ram(self, ram_index): index = ram_index query = index.parse_query("sea whale", ["title", "body"]) result = index.searcher().search(query, 10) - assert len(result) == 1 - _, doc_address = result[0] + assert len(result.hits) == 1 + _, doc_address = result.hits[0] searched_doc = index.searcher().doc(doc_address) assert searched_doc["title"] == ["The Old Man and the Sea"] @@ -105,12 +105,12 @@ class TestClass(object): result = searcher.search(query, 10) # summer isn't present - assert len(result) == 0 + assert len(result.hits) == 0 query = index.parse_query("title:men AND body:winter", ["title", "body"]) result = searcher.search(query) - assert len(result) == 1 + assert len(result.hits) == 1 def test_and_query_parser_default_fields(self, ram_index): query = ram_index.parse_query("winter", default_field_names=["title"]) @@ -131,40 +131,11 @@ class TestClass(object): with pytest.raises(ValueError): index.parse_query("bod:men", ["title", "body"]) - def test_sort_by_search(self): - schema = ( - SchemaBuilder() - .add_text_field("message", stored=True) - .add_unsigned_field("timestamp", fast="single", stored=True) - .build() - ) - index = Index(schema) - writer = index.writer() - doc = Document() - doc.add_text("message", "Test message") - doc.add_unsigned("timestamp", 1569954264) - writer.add_document(doc) - - doc = Document() - doc.add_text("message", "Another test message") - doc.add_unsigned("timestamp", 1569954280) - - writer.add_document(doc) - - writer.commit() - index.reload() - - query = index.parse_query("test") - result = index.searcher().search(query, 10, sort_by="timestamp") - # assert result[0][0] == first_doc["timestamp"] - # assert result[1][0] == second_doc["timestamp"] - - class TestUpdateClass(object): def test_delete_update(self, ram_index): query = ram_index.parse_query("Frankenstein", ["title"]) result = ram_index.searcher().search(query, 10) - assert len(result) == 1 + assert len(result.hits) == 1 writer = ram_index.writer() @@ -179,7 +150,7 @@ class TestUpdateClass(object): ram_index.reload() result = ram_index.searcher().search(query) - assert len(result) == 0 + assert len(result.hits) == 0 PATH_TO_INDEX = "tests/test_index/"