diff --git a/src/lib.rs b/src/lib.rs index 5cb0826..be75122 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ use facet::Facet; use index::Index; use schema::Schema; use schemabuilder::SchemaBuilder; -use searcher::{DocAddress, Searcher, TopDocs}; +use searcher::{DocAddress, Searcher}; /// Python bindings for the search engine library Tantivy. /// @@ -76,7 +76,6 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; Ok(()) } diff --git a/src/searcher.rs b/src/searcher.rs index b336ab7..6871763 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -29,20 +29,25 @@ impl Searcher { /// search results. /// /// Raises a ValueError if there was an error with the search. + #[args(limit = 10)] fn search( &self, + py: Python, query: &Query, - collector: &mut TopDocs, - ) -> PyResult> { - let ret = self.inner.search(&query.inner, &collector.inner); + limit: usize, + ) -> PyResult> { + let collector = tv::collector::TopDocs::with_limit(limit); + let ret = self.inner.search(&query.inner, &collector); + match ret { Ok(r) => { - let result: Vec<(f32, DocAddress)> = - r.iter().map(|(f, d)| (*f, DocAddress::from(d))).collect(); + let result: Vec<(PyObject, DocAddress)> = + r.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect(); Ok(result) } Err(e) => Err(exceptions::ValueError::py_err(e.to_string())), } + } /// Returns the overall number of documents in the index. @@ -110,28 +115,6 @@ impl Into for &DocAddress { } } -/// The Top Score Collector keeps track of the K documents sorted by their -/// score. -/// -/// Args: -/// limit (int, optional): The number of documents that the top scorer will -/// retrieve. Must be a positive integer larger than 0. Defaults to 10. -#[pyclass] -pub(crate) struct TopDocs { - inner: tv::collector::TopDocs, -} - -#[pymethods] -impl TopDocs { - #[new] - #[args(limit = 10)] - fn new(obj: &PyRawObject, limit: usize) -> PyResult<()> { - let top = tv::collector::TopDocs::with_limit(limit); - obj.init(TopDocs { inner: top }); - Ok(()) - } -} - #[pyproto] impl PyObjectProtocol for Searcher { fn __repr__(&self) -> PyResult { diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 347baa8..c57b90a 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -76,9 +76,7 @@ class TestClass(object): _, index = dir_index query = index.parse_query("sea whale", ["title", "body"]) - top_docs = tantivy.TopDocs(10) - - result = index.searcher().search(query, top_docs) + result = index.searcher().search(query, 10) assert len(result) == 1 def test_simple_search_after_reuse(self, dir_index): @@ -86,18 +84,14 @@ class TestClass(object): index = Index(schema(), str(index_dir)) query = index.parse_query("sea whale", ["title", "body"]) - top_docs = tantivy.TopDocs(10) - - result = index.searcher().search(query, top_docs) + result = index.searcher().search(query, 10) assert len(result) == 1 def test_simple_search_in_ram(self, ram_index): index = ram_index query = index.parse_query("sea whale", ["title", "body"]) - top_docs = tantivy.TopDocs(10) - - result = index.searcher().search(query, top_docs) + result = index.searcher().search(query, 10) assert len(result) == 1 _, doc_address = result[0] searched_doc = index.searcher().doc(doc_address) @@ -107,15 +101,14 @@ class TestClass(object): index = ram_index query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"]) # look for an intersection of documents - top_docs = tantivy.TopDocs(10) searcher = index.searcher() - result = searcher.search(query, top_docs) + result = searcher.search(query, 10) # summer isn't present assert len(result) == 0 query = index.parse_query("title:men AND body:winter", ["title", "body"]) - result = searcher.search(query, top_docs) + result = searcher.search(query) assert len(result) == 1 @@ -142,8 +135,7 @@ class TestClass(object): class TestUpdateClass(object): def test_delete_update(self, ram_index): query = ram_index.parse_query("Frankenstein", ["title"]) - top_docs = tantivy.TopDocs(10) - result = ram_index.searcher().search(query, top_docs) + result = ram_index.searcher().search(query, 10) assert len(result) == 1 writer = ram_index.writer() @@ -158,7 +150,7 @@ class TestUpdateClass(object): writer.commit() ram_index.reload() - result = ram_index.searcher().search(query, top_docs) + result = ram_index.searcher().search(query) assert len(result) == 0