searcher: Allow the search to be sorted by an unsigned field.

master
Damir Jelić 2019-10-01 20:56:42 +02:00
parent fbea6fe633
commit d46417c220
4 changed files with 64 additions and 21 deletions

View File

@ -8,7 +8,7 @@ use crate::document::{extract_value, Document};
use crate::query::Query; use crate::query::Query;
use crate::schema::Schema; use crate::schema::Schema;
use crate::searcher::Searcher; use crate::searcher::Searcher;
use crate::to_pyerr; use crate::{to_pyerr, get_field};
use tantivy as tv; use tantivy as tv;
use tantivy::directory::MmapDirectory; use tantivy::directory::MmapDirectory;
use tantivy::schema::{Field, NamedFieldDocument, Term, Value}; use tantivy::schema::{Field, NamedFieldDocument, Term, Value};
@ -111,13 +111,7 @@ impl IndexWriter {
field_name: &str, field_name: &str,
field_value: &PyAny, field_value: &PyAny,
) -> PyResult<u64> { ) -> PyResult<u64> {
let field = self.schema.get_field(field_name).ok_or_else(|| { let field = get_field(&self.schema, field_name)?;
exceptions::ValueError::py_err(format!(
"Field `{}` is not defined in the schema.",
field_name
))
})?;
let value = extract_value(field_value)?; let value = extract_value(field_value)?;
let term = match value { let term = match value {
Value::Str(text) => Term::from_field_text(field, &text), Value::Str(text) => Term::from_field_text(field, &text),
@ -274,6 +268,7 @@ impl Index {
fn searcher(&self) -> Searcher { fn searcher(&self) -> Searcher {
Searcher { Searcher {
inner: self.reader.searcher(), inner: self.reader.searcher(),
schema: self.index.schema(),
} }
} }

View File

@ -1,5 +1,6 @@
use pyo3::exceptions; use pyo3::exceptions;
use pyo3::prelude::*; use pyo3::prelude::*;
use tantivy as tv;
mod document; mod document;
mod facet; mod facet;
@ -83,3 +84,14 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
pub(crate) fn to_pyerr<E: ToString>(err: E) -> PyErr { pub(crate) fn to_pyerr<E: ToString>(err: E) -> PyErr {
exceptions::ValueError::py_err(err.to_string()) exceptions::ValueError::py_err(err.to_string())
} }
pub(crate) fn get_field(schema: &tv::schema::Schema, field_name: &str) -> PyResult<tv::schema::Field> {
let field = schema.get_field(field_name).ok_or_else(|| {
exceptions::ValueError::py_err(format!(
"Field `{}` is not defined in the schema.",
field_name
))
})?;
Ok(field)
}

View File

@ -2,9 +2,9 @@
use crate::document::Document; use crate::document::Document;
use crate::query::Query; use crate::query::Query;
use crate::to_pyerr; use crate::{to_pyerr, get_field};
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{exceptions, PyObjectProtocol}; use pyo3::PyObjectProtocol;
use tantivy as tv; use tantivy as tv;
/// Tantivy's Searcher class /// Tantivy's Searcher class
@ -13,8 +13,11 @@ use tantivy as tv;
#[pyclass] #[pyclass]
pub(crate) struct Searcher { pub(crate) struct Searcher {
pub(crate) inner: tv::LeasedItem<tv::Searcher>, pub(crate) inner: tv::LeasedItem<tv::Searcher>,
pub(crate) schema: tv::schema::Schema,
} }
const SORT_BY: &str = "";
#[pymethods] #[pymethods]
impl Searcher { impl Searcher {
/// Search the index with the given query and collect results. /// Search the index with the given query and collect results.
@ -29,25 +32,30 @@ impl Searcher {
/// search results. /// search results.
/// ///
/// Raises a ValueError if there was an error with the search. /// Raises a ValueError if there was an error with the search.
#[args(limit = 10)] #[args(limit = 10, sort_by = "SORT_BY")]
fn search( fn search(
&self, &self,
py: Python, py: Python,
query: &Query, query: &Query,
limit: usize, limit: usize,
sort_by: &str,
) -> PyResult<Vec<(PyObject, DocAddress)>> { ) -> PyResult<Vec<(PyObject, DocAddress)>> {
let collector = tv::collector::TopDocs::with_limit(limit); let field = match sort_by {
let ret = self.inner.search(&query.inner, &collector); "" => None,
field_name => Some(get_field(&self.schema, field_name)?)
};
match ret { let result = if let Some(f) = field {
Ok(r) => { let collector = tv::collector::TopDocs::with_limit(limit).order_by_u64_field(f);
let result: Vec<(PyObject, DocAddress)> = let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?;
r.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect(); ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect()
Ok(result) } else {
} let collector = tv::collector::TopDocs::with_limit(limit);
Err(e) => Err(exceptions::ValueError::py_err(e.to_string())), let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?;
} ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect()
};
Ok(result)
} }
/// Returns the overall number of documents in the index. /// Returns the overall number of documents in the index.

View File

@ -131,6 +131,34 @@ class TestClass(object):
with pytest.raises(ValueError): with pytest.raises(ValueError):
index.parse_query("bod:men", ["title", "body"]) index.parse_query("bod:men", ["title", "body"])
def test_sort_by_search(self):
schema = (
SchemaBuilder()
.add_text_field("message", stored=True)
.add_unsigned_field("timestamp", stored=True, fast="single")
.build()
)
index = Index(schema)
writer = index.writer()
doc = Document()
doc.add_text("message", "Test message")
doc.add_unsigned("timestamp", 1569954264)
writer.add_document(doc)
doc = Document()
doc.add_text("message", "Another test message")
doc.add_unsigned("timestamp", 1569954280)
writer.add_document(doc)
writer.commit()
index.reload()
query = index.parse_query("test")
result = index.searcher().search(query, 10, sort_by="timestamp")
assert result[0][0] == 1569954280
assert result[1][0] == 1569954264
class TestUpdateClass(object): class TestUpdateClass(object):
def test_delete_update(self, ram_index): def test_delete_update(self, ram_index):