diff --git a/src/index.rs b/src/index.rs index 546a1be..6a05012 100644 --- a/src/index.rs +++ b/src/index.rs @@ -8,7 +8,7 @@ use crate::document::{extract_value, Document}; use crate::query::Query; use crate::schema::Schema; use crate::searcher::Searcher; -use crate::to_pyerr; +use crate::{to_pyerr, get_field}; use tantivy as tv; use tantivy::directory::MmapDirectory; use tantivy::schema::{NamedFieldDocument, Term, Value}; @@ -111,13 +111,7 @@ impl IndexWriter { field_name: &str, field_value: &PyAny, ) -> PyResult { - let field = self.schema.get_field(field_name).ok_or_else(|| { - exceptions::ValueError::py_err(format!( - "Field `{}` is not defined in the schema.", - field_name - )) - })?; - + let field = get_field(&self.schema, field_name)?; let value = extract_value(field_value)?; let term = match value { Value::Str(text) => Term::from_field_text(field, &text), @@ -274,6 +268,7 @@ impl Index { fn searcher(&self) -> Searcher { Searcher { inner: self.reader.searcher(), + schema: self.index.schema(), } } diff --git a/src/lib.rs b/src/lib.rs index 5cb0826..77cf565 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use pyo3::exceptions; use pyo3::prelude::*; +use tantivy as tv; mod document; mod facet; @@ -14,7 +15,7 @@ use facet::Facet; use index::Index; use schema::Schema; use schemabuilder::SchemaBuilder; -use searcher::{DocAddress, Searcher, TopDocs}; +use searcher::{DocAddress, Searcher}; /// Python bindings for the search engine library Tantivy. 
/// @@ -76,7 +77,6 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; Ok(()) } @@ -84,3 +84,14 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> { pub(crate) fn to_pyerr(err: E) -> PyErr { exceptions::ValueError::py_err(err.to_string()) } + +pub(crate) fn get_field(schema: &tv::schema::Schema, field_name: &str) -> PyResult { + let field = schema.get_field(field_name).ok_or_else(|| { + exceptions::ValueError::py_err(format!( + "Field `{}` is not defined in the schema.", + field_name + )) + })?; + + Ok(field) +} diff --git a/src/searcher.rs b/src/searcher.rs index b336ab7..243c2b4 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -2,10 +2,12 @@ use crate::document::Document; use crate::query::Query; -use crate::to_pyerr; +use crate::{to_pyerr}; +use pyo3::exceptions::ValueError; use pyo3::prelude::*; -use pyo3::{exceptions, PyObjectProtocol}; +use pyo3::PyObjectProtocol; use tantivy as tv; +use tantivy::collector::{Count, MultiCollector, TopDocs}; /// Tantivy's Searcher class /// @@ -13,6 +15,32 @@ use tantivy as tv; #[pyclass] pub(crate) struct Searcher { pub(crate) inner: tv::LeasedItem, + pub(crate) schema: tv::schema::Schema, +} + +#[pyclass] +/// Object holding the results of a successful search. +pub(crate) struct SearchResult { + hits: Vec<(PyObject, DocAddress)>, + #[pyo3(get)] + /// How many documents matched the query. Only available if `count` was set + /// to true during the search. + count: Option, +} + +#[pymethods] +impl SearchResult { + #[getter] + /// The list of tuples that contains the scores and DocAddress of the + /// search results. 
+ fn hits(&self, py: Python) -> PyResult> { + let ret: Vec<(PyObject, DocAddress)> = self + .hits + .iter() + .map(|(obj, address)| (obj.clone_ref(py), address.clone())) + .collect(); + Ok(ret) + } } #[pymethods] @@ -21,28 +49,54 @@ impl Searcher { /// /// Args: /// query (Query): The query that will be used for the search. - /// collector (Collector): A collector that determines how the search - /// results will be collected. Only the TopDocs collector is - /// supported for now. + /// limit (int, optional): The maximum number of search results to + /// return. Defaults to 10. + /// count (bool, optional): Should the number of documents that match + /// the query be returned as well. Defaults to true. /// - /// Returns a list of tuples that contains the scores and DocAddress of the - /// search results. + /// Returns `SearchResult` object. /// /// Raises a ValueError if there was an error with the search. + #[args(limit = 10, count = true)] fn search( &self, + py: Python, query: &Query, - collector: &mut TopDocs, - ) -> PyResult> { - let ret = self.inner.search(&query.inner, &collector.inner); - match ret { - Ok(r) => { - let result: Vec<(f32, DocAddress)> = - r.iter().map(|(f, d)| (*f, DocAddress::from(d))).collect(); - Ok(result) + limit: usize, + count: bool, + ) -> PyResult { + let mut multicollector = MultiCollector::new(); + + let count_handle = if count { + Some(multicollector.add_collector(Count)) + } else { + None + }; + + let (mut multifruit, hits) = { + let collector = TopDocs::with_limit(limit); + let top_docs_handle = multicollector.add_collector(collector); + let ret = self.inner.search(&query.inner, &multicollector); + + match ret { + Ok(mut r) => { + let top_docs = top_docs_handle.extract(&mut r); + let result: Vec<(PyObject, DocAddress)> = top_docs + .iter() + .map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))) + .collect(); + (r, result) + } + Err(e) => return Err(ValueError::py_err(e.to_string())), } - Err(e) => 
Err(exceptions::ValueError::py_err(e.to_string())), - } + }; + + let count = match count_handle { + Some(h) => Some(h.extract(&mut multifruit)), + None => None, + }; + + Ok(SearchResult { hits, count }) } /// Returns the overall number of documents in the index. @@ -74,6 +128,7 @@ impl Searcher { /// The id used for the segment is actually an ordinal in the list of segment /// hold by a Searcher. #[pyclass] +#[derive(Clone)] pub(crate) struct DocAddress { pub(crate) segment_ord: tv::SegmentLocalId, pub(crate) doc: tv::DocId, @@ -110,28 +165,6 @@ impl Into for &DocAddress { } } -/// The Top Score Collector keeps track of the K documents sorted by their -/// score. -/// -/// Args: -/// limit (int, optional): The number of documents that the top scorer will -/// retrieve. Must be a positive integer larger than 0. Defaults to 10. -#[pyclass] -pub(crate) struct TopDocs { - inner: tv::collector::TopDocs, -} - -#[pymethods] -impl TopDocs { - #[new] - #[args(limit = 10)] - fn new(obj: &PyRawObject, limit: usize) -> PyResult<()> { - let top = tv::collector::TopDocs::with_limit(limit); - obj.init(TopDocs { inner: top }); - Ok(()) - } -} - #[pyproto] impl PyObjectProtocol for Searcher { fn __repr__(&self) -> PyResult { diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 821fa45..7264f5a 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -76,30 +76,24 @@ class TestClass(object): _, index = dir_index query = index.parse_query("sea whale", ["title", "body"]) - top_docs = tantivy.TopDocs(10) - - result = index.searcher().search(query, top_docs) - assert len(result) == 1 + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 def test_simple_search_after_reuse(self, dir_index): index_dir, _ = dir_index index = Index(schema(), str(index_dir)) query = index.parse_query("sea whale", ["title", "body"]) - top_docs = tantivy.TopDocs(10) - - result = index.searcher().search(query, top_docs) - assert len(result) == 1 + result = 
index.searcher().search(query, 10) + assert len(result.hits) == 1 def test_simple_search_in_ram(self, ram_index): index = ram_index query = index.parse_query("sea whale", ["title", "body"]) - top_docs = tantivy.TopDocs(10) - - result = index.searcher().search(query, top_docs) - assert len(result) == 1 - _, doc_address = result[0] + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + _, doc_address = result.hits[0] searched_doc = index.searcher().doc(doc_address) assert searched_doc["title"] == ["The Old Man and the Sea"] @@ -107,17 +101,16 @@ class TestClass(object): index = ram_index query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"]) # look for an intersection of documents - top_docs = tantivy.TopDocs(10) searcher = index.searcher() - result = searcher.search(query, top_docs) + result = searcher.search(query, 10) # summer isn't present - assert len(result) == 0 + assert len(result.hits) == 0 query = index.parse_query("title:men AND body:winter", ["title", "body"]) - result = searcher.search(query, top_docs) + result = searcher.search(query) - assert len(result) == 1 + assert len(result.hits) == 1 def test_and_query_parser_default_fields(self, ram_index): query = ram_index.parse_query("winter", default_field_names=["title"]) @@ -138,13 +131,11 @@ class TestClass(object): with pytest.raises(ValueError): index.parse_query("bod:men", ["title", "body"]) - class TestUpdateClass(object): def test_delete_update(self, ram_index): query = ram_index.parse_query("Frankenstein", ["title"]) - top_docs = tantivy.TopDocs(10) - result = ram_index.searcher().search(query, top_docs) - assert len(result) == 1 + result = ram_index.searcher().search(query, 10) + assert len(result.hits) == 1 writer = ram_index.writer() @@ -158,8 +149,8 @@ class TestUpdateClass(object): writer.commit() ram_index.reload() - result = ram_index.searcher().search(query, top_docs) - assert len(result) == 0 + result = 
ram_index.searcher().search(query) + assert len(result.hits) == 0 PATH_TO_INDEX = "tests/test_index/"