diff --git a/src/index.rs b/src/index.rs index c20af98..ac7f5ef 100644 --- a/src/index.rs +++ b/src/index.rs @@ -8,7 +8,7 @@ use crate::document::{extract_value, Document}; use crate::query::Query; use crate::schema::Schema; use crate::searcher::Searcher; -use crate::to_pyerr; +use crate::{to_pyerr, get_field}; use tantivy as tv; use tantivy::directory::MmapDirectory; use tantivy::schema::{Field, NamedFieldDocument, Term, Value}; @@ -111,13 +111,7 @@ impl IndexWriter { field_name: &str, field_value: &PyAny, ) -> PyResult { - let field = self.schema.get_field(field_name).ok_or_else(|| { - exceptions::ValueError::py_err(format!( - "Field `{}` is not defined in the schema.", - field_name - )) - })?; - + let field = get_field(&self.schema, field_name)?; let value = extract_value(field_value)?; let term = match value { Value::Str(text) => Term::from_field_text(field, &text), @@ -274,6 +268,7 @@ impl Index { fn searcher(&self) -> Searcher { Searcher { inner: self.reader.searcher(), + schema: self.index.schema(), } } diff --git a/src/lib.rs b/src/lib.rs index be75122..77cf565 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use pyo3::exceptions; use pyo3::prelude::*; +use tantivy as tv; mod document; mod facet; @@ -83,3 +84,14 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> { pub(crate) fn to_pyerr(err: E) -> PyErr { exceptions::ValueError::py_err(err.to_string()) } + +pub(crate) fn get_field(schema: &tv::schema::Schema, field_name: &str) -> PyResult { + let field = schema.get_field(field_name).ok_or_else(|| { + exceptions::ValueError::py_err(format!( + "Field `{}` is not defined in the schema.", + field_name + )) + })?; + + Ok(field) +} diff --git a/src/searcher.rs b/src/searcher.rs index 6871763..1de5e2b 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -2,9 +2,9 @@ use crate::document::Document; use crate::query::Query; -use crate::to_pyerr; +use crate::{to_pyerr, get_field}; use pyo3::prelude::*; -use pyo3::{exceptions, PyObjectProtocol}; +use pyo3::PyObjectProtocol; use tantivy as tv; /// Tantivy's Searcher class @@ -13,8 +13,11 @@ use tantivy as tv; #[pyclass] pub(crate) struct Searcher { pub(crate) inner: tv::LeasedItem, + pub(crate) schema: tv::schema::Schema, } +const SORT_BY: &str = ""; + #[pymethods] impl Searcher { /// Search the index with the given query and collect results. @@ -29,25 +32,30 @@ impl Searcher { /// search results. /// /// Raises a ValueError if there was an error with the search. - #[args(limit = 10)] + #[args(limit = 10, sort_by = "SORT_BY")] fn search( &self, py: Python, query: &Query, limit: usize, + sort_by: &str, ) -> PyResult> { - let collector = tv::collector::TopDocs::with_limit(limit); - let ret = self.inner.search(&query.inner, &collector); + let field = match sort_by { + "" => None, + field_name => Some(get_field(&self.schema, field_name)?) + }; - match ret { - Ok(r) => { - let result: Vec<(PyObject, DocAddress)> = - r.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect(); - Ok(result) - } - Err(e) => Err(exceptions::ValueError::py_err(e.to_string())), - } + let result = if let Some(f) = field { + let collector = tv::collector::TopDocs::with_limit(limit).order_by_u64_field(f); + let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?; + ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect() + } else { + let collector = tv::collector::TopDocs::with_limit(limit); + let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?; + ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect() + }; + Ok(result) } /// Returns the overall number of documents in the index. diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index c57b90a..0999b3c 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -131,6 +131,34 @@ class TestClass(object): with pytest.raises(ValueError): index.parse_query("bod:men", ["title", "body"]) + def test_sort_by_search(self): + schema = ( + SchemaBuilder() + .add_text_field("message", stored=True) + .add_unsigned_field("timestamp", stored=True, fast="single") + .build() + ) + index = Index(schema) + writer = index.writer() + doc = Document() + doc.add_text("message", "Test message") + doc.add_unsigned("timestamp", 1569954264) + writer.add_document(doc) + + doc = Document() + doc.add_text("message", "Another test message") + doc.add_unsigned("timestamp", 1569954280) + + writer.add_document(doc) + + writer.commit() + index.reload() + + query = index.parse_query("test") + result = index.searcher().search(query, 10, sort_by="timestamp") + assert result[0][0] == 1569954280 + assert result[1][0] == 1569954264 + class TestUpdateClass(object): def test_delete_update(self, ram_index):