feat: Aggregations API (#288)
parent
e3de7b1aab
commit
8ece24161b
|
@ -1,9 +1,11 @@
|
|||
#![allow(clippy::new_ret_no_self)]
|
||||
|
||||
use crate::{document::Document, query::Query, to_pyerr};
|
||||
use pyo3::types::PyDict;
|
||||
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tantivy as tv;
|
||||
use tantivy::aggregation::AggregationCollector;
|
||||
use tantivy::collector::{Count, MultiCollector, TopDocs};
|
||||
use tantivy::TantivyDocument;
|
||||
// Bring the trait into scope. This is required for the `to_named_doc` method.
|
||||
|
@ -233,6 +235,35 @@ impl Searcher {
|
|||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (query, agg))]
|
||||
fn aggregate(
|
||||
&self,
|
||||
py: Python,
|
||||
query: &Query,
|
||||
agg: Py<PyDict>,
|
||||
) -> PyResult<Py<PyDict>> {
|
||||
let py_json = py.import_bound("json")?;
|
||||
let agg_query_str = py_json.call_method1("dumps", (agg,))?.to_string();
|
||||
|
||||
let agg_str = py.allow_threads(move || {
|
||||
let agg_collector = AggregationCollector::from_aggs(
|
||||
serde_json::from_str(&agg_query_str).map_err(to_pyerr)?,
|
||||
Default::default(),
|
||||
);
|
||||
let agg_res = self
|
||||
.inner
|
||||
.search(query.get(), &agg_collector)
|
||||
.map_err(to_pyerr)?;
|
||||
|
||||
serde_json::to_string(&agg_res).map_err(to_pyerr)
|
||||
})?;
|
||||
|
||||
let agg_dict = py_json.call_method1("loads", (agg_str,))?;
|
||||
let agg_dict = agg_dict.downcast::<PyDict>()?;
|
||||
|
||||
Ok(agg_dict.clone().unbind())
|
||||
}
|
||||
|
||||
/// Returns the overall number of documents in the index.
|
||||
#[getter]
|
||||
fn num_docs(&self) -> u64 {
|
||||
|
|
|
@ -289,6 +289,13 @@ class Searcher:
|
|||
) -> SearchResult:
|
||||
pass
|
||||
|
||||
def aggregate(
|
||||
self,
|
||||
search_query: Query,
|
||||
agg_query: dict,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@property
|
||||
def num_docs(self) -> int:
|
||||
pass
|
||||
|
|
|
@ -15,10 +15,10 @@ def schema():
|
|||
def schema_numeric_fields():
|
||||
return (
|
||||
SchemaBuilder()
|
||||
.add_integer_field("id", stored=True, indexed=True)
|
||||
.add_float_field("rating", stored=True, indexed=True)
|
||||
.add_integer_field("id", stored=True, indexed=True, fast=True)
|
||||
.add_float_field("rating", stored=True, indexed=True, fast=True)
|
||||
.add_boolean_field("is_good", stored=True, indexed=True)
|
||||
.add_text_field("body", stored=True)
|
||||
.add_text_field("body", stored=True, fast=True)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
|
|
@ -64,6 +64,60 @@ class TestClass(object):
|
|||
|
||||
assert len(result.hits) == 1
|
||||
|
||||
def test_and_aggregate(self, ram_index_numeric_fields):
|
||||
index = ram_index_numeric_fields
|
||||
query = Query.all_query()
|
||||
agg_query = {
|
||||
"top_hits_req": {
|
||||
"top_hits": {
|
||||
"size": 2,
|
||||
"sort": [{"rating": "desc"}],
|
||||
"from": 0,
|
||||
"docvalue_fields": ["rating", "id", "body"],
|
||||
}
|
||||
}
|
||||
}
|
||||
searcher = index.searcher()
|
||||
result = searcher.aggregate(query, agg_query)
|
||||
assert isinstance(result, dict)
|
||||
assert "top_hits_req" in result
|
||||
assert len(result["top_hits_req"]["hits"]) == 2
|
||||
for hit in result["top_hits_req"]["hits"]:
|
||||
assert len(hit["docvalue_fields"]) == 3
|
||||
|
||||
assert result == json.loads("""
|
||||
{
|
||||
"top_hits_req": {
|
||||
"hits": [
|
||||
{
|
||||
"sort": [ 13840124604862955520 ],
|
||||
"docvalue_fields": {
|
||||
"id": [ 2 ],
|
||||
"rating": [ 4.5 ],
|
||||
"body": [ "a", "few", "miles", "south", "of", "soledad", "the", "salinas", "river", "drops", "in", "close", "to", "the", "hillside",
|
||||
"bank", "and", "runs", "deep", "and", "green", "the", "water", "is", "warm", "too", "for", "it", "has", "slipped", "twinkling",
|
||||
"over", "the", "yellow", "sands", "in", "the", "sunlight", "before", "reaching", "the", "narrow", "pool",
|
||||
"on", "one", "side", "of", "the", "river", "the", "golden", "foothill", "slopes", "curve", "up",
|
||||
"to", "the", "strong", "and", "rocky", "gabilan", "mountains", "but", "on", "the", "valley", "side", "the",
|
||||
"water", "is", "lined", "with", "trees", "willows", "fresh", "and", "green", "with", "every", "spring", "carrying", "in", "their", "lower", "leaf",
|
||||
"junctures", "the", "debris", "of", "the", "winter", "s", "flooding", "and", "sycamores", "with", "mottled", "white", "recumbent", "limbs",
|
||||
"and", "branches", "that", "arch", "over", "the", "pool" ]
|
||||
}
|
||||
},
|
||||
{
|
||||
"sort": [ 13838435755002691584 ],
|
||||
"docvalue_fields": {
|
||||
"body": [ "he", "was", "an", "old", "man", "who", "fished", "alone", "in", "a", "skiff", "inthe", "gulf", "stream",
|
||||
"and", "he", "had", "gone", "eighty", "four", "days", "now", "without", "taking", "a", "fish" ],
|
||||
"rating": [ 3.5 ],
|
||||
"id": [ 1 ]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
def test_and_query_numeric_fields(self, ram_index_numeric_fields):
|
||||
index = ram_index_numeric_fields
|
||||
searcher = index.searcher()
|
||||
|
|
Loading…
Reference in New Issue