searcher: Allow the search to be sorted by an unsigned field.

master
Damir Jelić 2019-10-01 20:56:42 +02:00
parent fbea6fe633
commit d46417c220
4 changed files with 64 additions and 21 deletions

View File

@ -8,7 +8,7 @@ use crate::document::{extract_value, Document};
use crate::query::Query;
use crate::schema::Schema;
use crate::searcher::Searcher;
use crate::to_pyerr;
use crate::{to_pyerr, get_field};
use tantivy as tv;
use tantivy::directory::MmapDirectory;
use tantivy::schema::{Field, NamedFieldDocument, Term, Value};
@ -111,13 +111,7 @@ impl IndexWriter {
field_name: &str,
field_value: &PyAny,
) -> PyResult<u64> {
let field = self.schema.get_field(field_name).ok_or_else(|| {
exceptions::ValueError::py_err(format!(
"Field `{}` is not defined in the schema.",
field_name
))
})?;
let field = get_field(&self.schema, field_name)?;
let value = extract_value(field_value)?;
let term = match value {
Value::Str(text) => Term::from_field_text(field, &text),
@ -274,6 +268,7 @@ impl Index {
fn searcher(&self) -> Searcher {
Searcher {
inner: self.reader.searcher(),
schema: self.index.schema(),
}
}

View File

@ -1,5 +1,6 @@
use pyo3::exceptions;
use pyo3::prelude::*;
use tantivy as tv;
mod document;
mod facet;
@ -83,3 +84,14 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
pub(crate) fn to_pyerr<E: ToString>(err: E) -> PyErr {
exceptions::ValueError::py_err(err.to_string())
}
pub(crate) fn get_field(schema: &tv::schema::Schema, field_name: &str) -> PyResult<tv::schema::Field> {
let field = schema.get_field(field_name).ok_or_else(|| {
exceptions::ValueError::py_err(format!(
"Field `{}` is not defined in the schema.",
field_name
))
})?;
Ok(field)
}

View File

@ -2,9 +2,9 @@
use crate::document::Document;
use crate::query::Query;
use crate::to_pyerr;
use crate::{to_pyerr, get_field};
use pyo3::prelude::*;
use pyo3::{exceptions, PyObjectProtocol};
use pyo3::PyObjectProtocol;
use tantivy as tv;
/// Tantivy's Searcher class
@ -13,8 +13,11 @@ use tantivy as tv;
#[pyclass]
pub(crate) struct Searcher {
pub(crate) inner: tv::LeasedItem<tv::Searcher>,
pub(crate) schema: tv::schema::Schema,
}
const SORT_BY: &str = "";
#[pymethods]
impl Searcher {
/// Search the index with the given query and collect results.
@ -29,26 +32,31 @@ impl Searcher {
/// search results.
///
/// Raises a ValueError if there was an error with the search.
#[args(limit = 10)]
#[args(limit = 10, sort_by = "SORT_BY")]
fn search(
&self,
py: Python,
query: &Query,
limit: usize,
sort_by: &str,
) -> PyResult<Vec<(PyObject, DocAddress)>> {
let field = match sort_by {
"" => None,
field_name => Some(get_field(&self.schema, field_name)?)
};
let result = if let Some(f) = field {
let collector = tv::collector::TopDocs::with_limit(limit).order_by_u64_field(f);
let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?;
ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect()
} else {
let collector = tv::collector::TopDocs::with_limit(limit);
let ret = self.inner.search(&query.inner, &collector);
let ret = self.inner.search(&query.inner, &collector).map_err(to_pyerr)?;
ret.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect()
};
match ret {
Ok(r) => {
let result: Vec<(PyObject, DocAddress)> =
r.iter().map(|(f, d)| ((*f).into_py(py), DocAddress::from(d))).collect();
Ok(result)
}
Err(e) => Err(exceptions::ValueError::py_err(e.to_string())),
}
}
/// Returns the overall number of documents in the index.
#[getter]

View File

@ -131,6 +131,34 @@ class TestClass(object):
with pytest.raises(ValueError):
index.parse_query("bod:men", ["title", "body"])
def test_sort_by_search(self):
schema = (
SchemaBuilder()
.add_text_field("message", stored=True)
.add_unsigned_field("timestamp", stored=True, fast="single")
.build()
)
index = Index(schema)
writer = index.writer()
doc = Document()
doc.add_text("message", "Test message")
doc.add_unsigned("timestamp", 1569954264)
writer.add_document(doc)
doc = Document()
doc.add_text("message", "Another test message")
doc.add_unsigned("timestamp", 1569954280)
writer.add_document(doc)
writer.commit()
index.reload()
query = index.parse_query("test")
result = index.searcher().search(query, 10, sort_by="timestamp")
assert result[0][0] == 1569954280
assert result[1][0] == 1569954264
class TestUpdateClass(object):
def test_delete_update(self, ram_index):