feat: Aggregations API (#288)

master
Tushar 2024-06-09 16:42:57 +05:30 committed by GitHub
parent e3de7b1aab
commit 8ece24161b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 3 deletions

View File

@ -1,9 +1,11 @@
#![allow(clippy::new_ret_no_self)]
use crate::{document::Document, query::Query, to_pyerr};
use pyo3::types::PyDict;
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use tantivy as tv;
use tantivy::aggregation::AggregationCollector;
use tantivy::collector::{Count, MultiCollector, TopDocs};
use tantivy::TantivyDocument;
// Bring the trait into scope. This is required for the `to_named_doc` method.
@ -233,6 +235,35 @@ impl Searcher {
})
}
#[pyo3(signature = (query, agg))]
fn aggregate(
&self,
py: Python,
query: &Query,
agg: Py<PyDict>,
) -> PyResult<Py<PyDict>> {
let py_json = py.import_bound("json")?;
let agg_query_str = py_json.call_method1("dumps", (agg,))?.to_string();
let agg_str = py.allow_threads(move || {
let agg_collector = AggregationCollector::from_aggs(
serde_json::from_str(&agg_query_str).map_err(to_pyerr)?,
Default::default(),
);
let agg_res = self
.inner
.search(query.get(), &agg_collector)
.map_err(to_pyerr)?;
serde_json::to_string(&agg_res).map_err(to_pyerr)
})?;
let agg_dict = py_json.call_method1("loads", (agg_str,))?;
let agg_dict = agg_dict.downcast::<PyDict>()?;
Ok(agg_dict.clone().unbind())
}
/// Returns the overall number of documents in the index.
#[getter]
fn num_docs(&self) -> u64 {

View File

@ -289,6 +289,13 @@ class Searcher:
) -> SearchResult:
pass
def aggregate(
self,
search_query: Query,
agg_query: dict,
) -> dict:
pass
@property
def num_docs(self) -> int:
pass

View File

@ -15,10 +15,10 @@ def schema():
def schema_numeric_fields():
return (
SchemaBuilder()
.add_integer_field("id", stored=True, indexed=True)
.add_float_field("rating", stored=True, indexed=True)
.add_integer_field("id", stored=True, indexed=True, fast=True)
.add_float_field("rating", stored=True, indexed=True, fast=True)
.add_boolean_field("is_good", stored=True, indexed=True)
.add_text_field("body", stored=True)
.add_text_field("body", stored=True, fast=True)
.build()
)

View File

@ -64,6 +64,60 @@ class TestClass(object):
assert len(result.hits) == 1
def test_and_aggregate(self, ram_index_numeric_fields):
index = ram_index_numeric_fields
query = Query.all_query()
agg_query = {
"top_hits_req": {
"top_hits": {
"size": 2,
"sort": [{"rating": "desc"}],
"from": 0,
"docvalue_fields": ["rating", "id", "body"],
}
}
}
searcher = index.searcher()
result = searcher.aggregate(query, agg_query)
assert isinstance(result, dict)
assert "top_hits_req" in result
assert len(result["top_hits_req"]["hits"]) == 2
for hit in result["top_hits_req"]["hits"]:
assert len(hit["docvalue_fields"]) == 3
assert result == json.loads("""
{
"top_hits_req": {
"hits": [
{
"sort": [ 13840124604862955520 ],
"docvalue_fields": {
"id": [ 2 ],
"rating": [ 4.5 ],
"body": [ "a", "few", "miles", "south", "of", "soledad", "the", "salinas", "river", "drops", "in", "close", "to", "the", "hillside",
"bank", "and", "runs", "deep", "and", "green", "the", "water", "is", "warm", "too", "for", "it", "has", "slipped", "twinkling",
"over", "the", "yellow", "sands", "in", "the", "sunlight", "before", "reaching", "the", "narrow", "pool",
"on", "one", "side", "of", "the", "river", "the", "golden", "foothill", "slopes", "curve", "up",
"to", "the", "strong", "and", "rocky", "gabilan", "mountains", "but", "on", "the", "valley", "side", "the",
"water", "is", "lined", "with", "trees", "willows", "fresh", "and", "green", "with", "every", "spring", "carrying", "in", "their", "lower", "leaf",
"junctures", "the", "debris", "of", "the", "winter", "s", "flooding", "and", "sycamores", "with", "mottled", "white", "recumbent", "limbs",
"and", "branches", "that", "arch", "over", "the", "pool" ]
}
},
{
"sort": [ 13838435755002691584 ],
"docvalue_fields": {
"body": [ "he", "was", "an", "old", "man", "who", "fished", "alone", "in", "a", "skiff", "inthe", "gulf", "stream",
"and", "he", "had", "gone", "eighty", "four", "days", "now", "without", "taking", "a", "fish" ],
"rating": [ 3.5 ],
"id": [ 1 ]
}
}
]
}
}
""")
def test_and_query_numeric_fields(self, ram_index_numeric_fields):
index = ram_index_numeric_fields
searcher = index.searcher()