diff --git a/src/searcher.rs b/src/searcher.rs
index 97054a0..f2ba1d3 100644
--- a/src/searcher.rs
+++ b/src/searcher.rs
@@ -2,7 +2,7 @@
 
 use crate::document::Document;
 use crate::query::Query;
-use crate::to_pyerr;
+use crate::{get_field, to_pyerr};
 use pyo3::exceptions::ValueError;
 use pyo3::prelude::*;
 use pyo3::PyObjectProtocol;
@@ -17,16 +17,54 @@ pub(crate) struct Searcher {
     pub(crate) inner: tv::LeasedItem<tv::Searcher>,
 }
 
+#[derive(Clone)]
+enum Fruit {
+    Score(f32),
+    Order(u64),
+}
+
+impl std::fmt::Debug for Fruit {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Fruit::Score(s) => write!(f, "{}", s),
+            Fruit::Order(o) => write!(f, "{}", o),
+        }
+    }
+}
+
+impl ToPyObject for Fruit {
+    fn to_object(&self, py: Python) -> PyObject {
+        match self {
+            Fruit::Score(s) => s.to_object(py),
+            Fruit::Order(o) => o.to_object(py),
+        }
+    }
+}
+
 #[pyclass]
 /// Object holding a results successful search.
 pub(crate) struct SearchResult {
-    hits: Vec<(PyObject, DocAddress)>,
+    hits: Vec<(Fruit, DocAddress)>,
     #[pyo3(get)]
     /// How many documents matched the query. Only available if `count` was set
     /// to true during the search.
     count: Option<usize>,
 }
 
+#[pyproto]
+impl PyObjectProtocol for SearchResult {
+    fn __repr__(&self) -> PyResult<String> {
+        if let Some(count) = self.count {
+            Ok(format!(
+                "SearchResult(hits: {:?}, count: {})",
+                self.hits, count
+            ))
+        } else {
+            Ok(format!("SearchResult(hits: {:?})", self.hits))
+        }
+    }
+}
+
 #[pymethods]
 impl SearchResult {
     #[getter]
@@ -36,7 +74,7 @@ impl SearchResult {
         let ret: Vec<(PyObject, DocAddress)> = self
             .hits
             .iter()
-            .map(|(obj, address)| (obj.clone_ref(py), address.clone()))
+            .map(|(result, address)| (result.to_object(py), address.clone()))
             .collect();
         Ok(ret)
     }
@@ -51,7 +89,11 @@ impl Searcher {
     ///     limit (int, optional): The maximum number of search results to
     ///         return. Defaults to 10.
     ///     count (bool, optional): Should the number of documents that match
-    ///         the query be returned as well. Defaults to true.
+    ///         the query be returned as well. Defaults to true.
+    ///     order_by_field (str, optional): The name of a schema field that
+    ///         the results should be ordered by. The field must be declared
+    ///         as a fast field when building the schema. Note, this only
+    ///         works for unsigned fields.
     ///
     /// Returns `SearchResult` object.
     ///
@@ -59,10 +101,11 @@ impl Searcher {
     #[args(limit = 10, count = true)]
     fn search(
         &self,
-        py: Python,
+        _py: Python,
         query: &Query,
         limit: usize,
         count: bool,
+        order_by_field: Option<&str>,
     ) -> PyResult<SearchResult> {
         let mut multicollector = MultiCollector::new();
 
@@ -73,20 +116,44 @@ impl Searcher {
         };
 
         let (mut multifruit, hits) = {
-            let collector = TopDocs::with_limit(limit);
-            let top_docs_handle = multicollector.add_collector(collector);
-            let ret = self.inner.search(&query.inner, &multicollector);
+            if let Some(order_by) = order_by_field {
+                let field = get_field(&self.inner.index().schema(), order_by)?;
+                let collector =
+                    TopDocs::with_limit(limit).order_by_u64_field(field);
+                let top_docs_handle = multicollector.add_collector(collector);
+                let ret = self.inner.search(&query.inner, &multicollector);
 
-            match ret {
-                Ok(mut r) => {
-                    let top_docs = top_docs_handle.extract(&mut r);
-                    let result: Vec<(PyObject, DocAddress)> = top_docs
-                        .iter()
-                        .map(|(f, d)| ((*f).into_py(py), DocAddress::from(d)))
-                        .collect();
-                    (r, result)
+                match ret {
+                    Ok(mut r) => {
+                        let top_docs = top_docs_handle.extract(&mut r);
+                        let result: Vec<(Fruit, DocAddress)> = top_docs
+                            .iter()
+                            .map(|(f, d)| {
+                                (Fruit::Order(*f), DocAddress::from(d))
+                            })
+                            .collect();
+                        (r, result)
+                    }
+                    Err(e) => return Err(ValueError::py_err(e.to_string())),
+                }
+            } else {
+                let collector = TopDocs::with_limit(limit);
+                let top_docs_handle = multicollector.add_collector(collector);
+                let ret = self.inner.search(&query.inner, &multicollector);
+
+                match ret {
+                    Ok(mut r) => {
+                        let top_docs = top_docs_handle.extract(&mut r);
+                        let result: Vec<(Fruit, DocAddress)> = top_docs
+                            .iter()
+                            .map(|(f, d)| {
+                                (Fruit::Score(*f), DocAddress::from(d))
+                            })
+                            .collect();
+                        (r, result)
+                    }
+                    Err(e) => return Err(ValueError::py_err(e.to_string())),
                 }
-                Err(e) => return Err(ValueError::py_err(e.to_string())),
             }
         };
 
@@ -127,7 +194,7 @@ impl Searcher {
 /// The id used for the segment is actually an ordinal in the list of segment
 /// hold by a Searcher.
 #[pyclass]
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub(crate) struct DocAddress {
     pub(crate) segment_ord: tv::SegmentLocalId,
     pub(crate) doc: tv::DocId,
diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py
index 7264f5a..fc0ccc5 100644
--- a/tests/tantivy_test.py
+++ b/tests/tantivy_test.py
@@ -131,6 +131,78 @@ class TestClass(object):
         with pytest.raises(ValueError):
             index.parse_query("bod:men", ["title", "body"])
 
+    def test_order_by_search(self):
+        schema = (SchemaBuilder()
+            .add_unsigned_field("order", fast="single")
+            .add_text_field("title", stored=True).build()
+        )
+
+        index = Index(schema)
+        writer = index.writer()
+
+        doc = Document()
+        doc.add_unsigned("order", 0)
+        doc.add_text("title", "Test title")
+
+        writer.add_document(doc)
+
+        doc = Document()
+        doc.add_unsigned("order", 2)
+        doc.add_text("title", "Final test title")
+        writer.add_document(doc)
+
+        doc = Document()
+        doc.add_unsigned("order", 1)
+        doc.add_text("title", "Another test title")
+
+
+        writer.add_document(doc)
+
+        writer.commit()
+        index.reload()
+
+        query = index.parse_query("test")
+
+        searcher = index.searcher()
+        result = searcher.search(query, 10, order_by_field="order")
+
+        assert len(result.hits) == 3
+
+        _, doc_address = result.hits[0]
+        searched_doc = index.searcher().doc(doc_address)
+        assert searched_doc["title"] == ["Final test title"]
+
+        _, doc_address = result.hits[1]
+        searched_doc = index.searcher().doc(doc_address)
+        assert searched_doc["title"] == ["Another test title"]
+
+        _, doc_address = result.hits[2]
+        searched_doc = index.searcher().doc(doc_address)
+        assert searched_doc["title"] == ["Test title"]
+
+    def test_order_by_search_without_fast_field(self):
+        schema = (SchemaBuilder()
+            .add_unsigned_field("order")
+            .add_text_field("title", stored=True).build()
+        )
+
+        index = Index(schema)
+        writer = index.writer()
+
+        doc = Document()
+        doc.add_unsigned("order", 0)
+        doc.add_text("title", "Test title")
+
+        # NOTE(review): `doc` is built but never passed to
+        # `writer.add_document`, and nothing is committed, so this search runs
+        # with no documents written by this test - TODO confirm that is intended.
+        query = index.parse_query("test")
+
+        searcher = index.searcher()
+        result = searcher.search(query, 10, order_by_field="order")
+        assert len(result.hits) == 0
+
+
 class TestUpdateClass(object):
     def test_delete_update(self, ram_index):
         query = ram_index.parse_query("Frankenstein", ["title"])