searcher: Add support to search and order the results by a field.

master
Damir Jelić 2020-04-19 12:26:08 +02:00
parent 1d80c19434
commit 094f8974ea
2 changed files with 154 additions and 18 deletions

View File

@ -2,7 +2,7 @@
use crate::document::Document;
use crate::query::Query;
use crate::to_pyerr;
use crate::{get_field, to_pyerr};
use pyo3::exceptions::ValueError;
use pyo3::prelude::*;
use pyo3::PyObjectProtocol;
@ -17,16 +17,54 @@ pub(crate) struct Searcher {
pub(crate) inner: tv::LeasedItem<tv::Searcher>,
}
/// The value a search hit was ranked by.
///
/// A hit carries either the relevance score computed for the query or the
/// value of the unsigned field the results were ordered by.
#[derive(Clone)]
enum Fruit {
    /// Relevance score of the hit.
    Score(f32),
    /// Value of the `u64` field the hit was ordered by.
    Order(u64),
}

impl std::fmt::Debug for Fruit {
    /// Prints only the wrapped number, without the variant name, so that a
    /// list of hits stays compact when rendered with `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!` streams straight into the formatter; no temporary `String`
        // has to be allocated just to be copied out again.
        match self {
            Fruit::Score(score) => write!(f, "{}", score),
            Fruit::Order(order) => write!(f, "{}", order),
        }
    }
}
impl ToPyObject for Fruit {
    /// Converts the wrapped number into a Python object by delegating to the
    /// `to_object` conversion of the inner `f32` or `u64`.
    fn to_object(&self, py: Python) -> PyObject {
        // Both payloads are `Copy`, so they can be bound by value.
        match *self {
            Fruit::Score(score) => score.to_object(py),
            Fruit::Order(order) => order.to_object(py),
        }
    }
}
#[pyclass]
/// Object holding the results of a successful search.
pub(crate) struct SearchResult {
hits: Vec<(PyObject, DocAddress)>,
hits: Vec<(Fruit, DocAddress)>,
#[pyo3(get)]
/// How many documents matched the query. Only available if `count` was set
/// to true during the search.
count: Option<usize>,
}
#[pyproto]
impl PyObjectProtocol for SearchResult {
    /// Builds the Python `repr()` string: the debug rendering of the hits,
    /// followed by the match count when one was collected.
    fn __repr__(&self) -> PyResult<String> {
        let repr = match self.count {
            Some(total) => format!(
                "SearchResult(hits: {:?}, count: {})",
                self.hits, total
            ),
            None => format!("SearchResult(hits: {:?})", self.hits),
        };
        Ok(repr)
    }
}
#[pymethods]
impl SearchResult {
#[getter]
@ -36,7 +74,7 @@ impl SearchResult {
let ret: Vec<(PyObject, DocAddress)> = self
.hits
.iter()
.map(|(obj, address)| (obj.clone_ref(py), address.clone()))
.map(|(result, address)| (result.to_object(py), address.clone()))
.collect();
Ok(ret)
}
@ -51,7 +89,11 @@ impl Searcher {
/// limit (int, optional): The maximum number of search results to
/// return. Defaults to 10.
/// count (bool, optional): Should the number of documents that match
/// the query be returned as well. Defaults to true.
/// the query be returned as well. Defaults to true.
/// order_by_field (Field, optional): A schema field that the results
/// should be ordered by. The field must be declared as a fast field
/// when building the schema. Note, this only works for unsigned
/// fields.
///
/// Returns `SearchResult` object.
///
@ -59,10 +101,11 @@ impl Searcher {
#[args(limit = 10, count = true)]
fn search(
&self,
py: Python,
_py: Python,
query: &Query,
limit: usize,
count: bool,
order_by_field: Option<&str>,
) -> PyResult<SearchResult> {
let mut multicollector = MultiCollector::new();
@ -73,20 +116,44 @@ impl Searcher {
};
let (mut multifruit, hits) = {
let collector = TopDocs::with_limit(limit);
let top_docs_handle = multicollector.add_collector(collector);
let ret = self.inner.search(&query.inner, &multicollector);
if let Some(order_by) = order_by_field {
let field = get_field(&self.inner.index().schema(), order_by)?;
let collector =
TopDocs::with_limit(limit).order_by_u64_field(field);
let top_docs_handle = multicollector.add_collector(collector);
let ret = self.inner.search(&query.inner, &multicollector);
match ret {
Ok(mut r) => {
let top_docs = top_docs_handle.extract(&mut r);
let result: Vec<(PyObject, DocAddress)> = top_docs
.iter()
.map(|(f, d)| ((*f).into_py(py), DocAddress::from(d)))
.collect();
(r, result)
match ret {
Ok(mut r) => {
let top_docs = top_docs_handle.extract(&mut r);
let result: Vec<(Fruit, DocAddress)> = top_docs
.iter()
.map(|(f, d)| {
(Fruit::Order(*f), DocAddress::from(d))
})
.collect();
(r, result)
}
Err(e) => return Err(ValueError::py_err(e.to_string())),
}
} else {
let collector = TopDocs::with_limit(limit);
let top_docs_handle = multicollector.add_collector(collector);
let ret = self.inner.search(&query.inner, &multicollector);
match ret {
Ok(mut r) => {
let top_docs = top_docs_handle.extract(&mut r);
let result: Vec<(Fruit, DocAddress)> = top_docs
.iter()
.map(|(f, d)| {
(Fruit::Score(*f), DocAddress::from(d))
})
.collect();
(r, result)
}
Err(e) => return Err(ValueError::py_err(e.to_string())),
}
Err(e) => return Err(ValueError::py_err(e.to_string())),
}
};
@ -127,7 +194,7 @@ impl Searcher {
/// The id used for the segment is actually an ordinal in the list of segments
/// held by a Searcher.
#[pyclass]
#[derive(Clone)]
#[derive(Clone, Debug)]
pub(crate) struct DocAddress {
pub(crate) segment_ord: tv::SegmentLocalId,
pub(crate) doc: tv::DocId,

View File

@ -131,6 +131,75 @@ class TestClass(object):
with pytest.raises(ValueError):
index.parse_query("bod:men", ["title", "body"])
def test_order_by_search(self):
    """Hits ordered by a fast unsigned field come back highest value first."""
    schema = (
        SchemaBuilder()
        .add_unsigned_field("order", fast="single")
        .add_text_field("title", stored=True)
        .build()
    )
    index = Index(schema)
    writer = index.writer()

    # Documents are inserted out of order on purpose, so the expected ranking
    # can only come from the "order" field and not from insertion order.
    documents = [
        (0, "Test title"),
        (2, "Final test title"),
        (1, "Another test title"),
    ]
    for order, title in documents:
        doc = Document()
        doc.add_unsigned("order", order)
        doc.add_text("title", title)
        writer.add_document(doc)

    writer.commit()
    index.reload()

    query = index.parse_query("test")
    searcher = index.searcher()
    result = searcher.search(query, 10, order_by_field="order")
    assert len(result.hits) == 3

    # Titles listed by descending "order" value: 2, 1, 0.
    expected_titles = [
        "Final test title",
        "Another test title",
        "Test title",
    ]
    for position, expected in enumerate(expected_titles):
        _, doc_address = result.hits[position]
        searched_doc = index.searcher().doc(doc_address)
        assert searched_doc["title"] == [expected]
def test_order_by_search_without_fast_field(self):
    """Ordering by a field that was not declared as a fast field yields no hits."""
    # "order" is deliberately declared without the `fast` option.
    fields = (
        SchemaBuilder()
        .add_unsigned_field("order")
        .add_text_field("title", stored=True)
    )
    schema = fields.build()
    index = Index(schema)
    writer = index.writer()

    # NOTE(review): this document is built but never passed to
    # `writer.add_document`, and nothing is committed, so the search below runs
    # against an empty index. Confirm whether that is intentional.
    doc = Document()
    doc.add_unsigned("order", 0)
    doc.add_text("title", "Test title")

    query = index.parse_query("test")
    searcher = index.searcher()
    result = searcher.search(query, 10, order_by_field="order")
    assert len(result.hits) == 0
class TestUpdateClass(object):
def test_delete_update(self, ram_index):
query = ram_index.parse_query("Frankenstein", ["title"])