From 8ece24161ba67a3cba3663eccf061f9965a35743 Mon Sep 17 00:00:00 2001
From: Tushar
Date: Sun, 9 Jun 2024 16:42:57 +0530
Subject: [PATCH] feat: Aggregations API (#288)

---
 src/searcher.rs       | 45 ++++++++++++++++++++++++++++++++++++
 tantivy/tantivy.pyi   |  7 ++++++
 tests/conftest.py     |  6 ++---
 tests/tantivy_test.py | 54 +++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 109 insertions(+), 3 deletions(-)

diff --git a/src/searcher.rs b/src/searcher.rs
index e87ac4e..528e54f 100644
--- a/src/searcher.rs
+++ b/src/searcher.rs
@@ -1,9 +1,11 @@
 #![allow(clippy::new_ret_no_self)]
 
 use crate::{document::Document, query::Query, to_pyerr};
+use pyo3::types::PyDict;
 use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
 use serde::{Deserialize, Serialize};
 use tantivy as tv;
+use tantivy::aggregation::AggregationCollector;
 use tantivy::collector::{Count, MultiCollector, TopDocs};
 use tantivy::TantivyDocument;
 // Bring the trait into scope. This is required for the `to_named_doc` method.
@@ -233,6 +235,49 @@ impl Searcher {
         })
     }
 
+    /// Runs an aggregation over the documents matching `query`.
+    ///
+    /// Args:
+    ///     query (Query): The query selecting the documents to aggregate.
+    ///     agg (dict): The aggregation request. It is serialized with
+    ///         Python's `json.dumps` and parsed into the request handed to
+    ///         `AggregationCollector::from_aggs`.
+    ///
+    /// Returns the aggregation result as a dict, obtained by serializing the
+    /// result to JSON and loading it with Python's `json.loads`.
+    ///
+    /// Errors from parsing the request, running the search or serializing the
+    /// result are converted with `to_pyerr`.
+    #[pyo3(signature = (query, agg))]
+    fn aggregate(
+        &self,
+        py: Python,
+        query: &Query,
+        agg: Py<PyDict>,
+    ) -> PyResult<Py<PyDict>> {
+        let py_json = py.import_bound("json")?;
+        let agg_query_str = py_json.call_method1("dumps", (agg,))?.to_string();
+
+        // Release the GIL while the request is parsed and the search runs.
+        let agg_str = py.allow_threads(move || {
+            let agg_collector = AggregationCollector::from_aggs(
+                serde_json::from_str(&agg_query_str).map_err(to_pyerr)?,
+                Default::default(),
+            );
+            let agg_res = self
+                .inner
+                .search(query.get(), &agg_collector)
+                .map_err(to_pyerr)?;
+
+            serde_json::to_string(&agg_res).map_err(to_pyerr)
+        })?;
+
+        let agg_dict = py_json.call_method1("loads", (agg_str,))?;
+        let agg_dict = agg_dict.downcast::<PyDict>()?;
+
+        Ok(agg_dict.clone().unbind())
+    }
+
     /// Returns the overall number of documents in the index.
#[getter] fn num_docs(&self) -> u64 { diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 3934b62..8dfe9ad 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -289,6 +289,13 @@ class Searcher: ) -> SearchResult: pass + def aggregate( + self, + search_query: Query, + agg_query: dict, + ) -> dict: + pass + @property def num_docs(self) -> int: pass diff --git a/tests/conftest.py b/tests/conftest.py index 313fdba..74c7c80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,10 +15,10 @@ def schema(): def schema_numeric_fields(): return ( SchemaBuilder() - .add_integer_field("id", stored=True, indexed=True) - .add_float_field("rating", stored=True, indexed=True) + .add_integer_field("id", stored=True, indexed=True, fast=True) + .add_float_field("rating", stored=True, indexed=True, fast=True) .add_boolean_field("is_good", stored=True, indexed=True) - .add_text_field("body", stored=True) + .add_text_field("body", stored=True, fast=True) .build() ) diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index e2a77eb..3ad9a2f 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -64,6 +64,60 @@ class TestClass(object): assert len(result.hits) == 1 + def test_and_aggregate(self, ram_index_numeric_fields): + index = ram_index_numeric_fields + query = Query.all_query() + agg_query = { + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [{"rating": "desc"}], + "from": 0, + "docvalue_fields": ["rating", "id", "body"], + } + } + } + searcher = index.searcher() + result = searcher.aggregate(query, agg_query) + assert isinstance(result, dict) + assert "top_hits_req" in result + assert len(result["top_hits_req"]["hits"]) == 2 + for hit in result["top_hits_req"]["hits"]: + assert len(hit["docvalue_fields"]) == 3 + + assert result == json.loads(""" +{ +"top_hits_req": { + "hits": [ + { + "sort": [ 13840124604862955520 ], + "docvalue_fields": { + "id": [ 2 ], + "rating": [ 4.5 ], + "body": [ "a", "few", "miles", "south", "of", 
"soledad", "the", "salinas", "river", "drops", "in", "close", "to", "the", "hillside", + "bank", "and", "runs", "deep", "and", "green", "the", "water", "is", "warm", "too", "for", "it", "has", "slipped", "twinkling", + "over", "the", "yellow", "sands", "in", "the", "sunlight", "before", "reaching", "the", "narrow", "pool", + "on", "one", "side", "of", "the", "river", "the", "golden", "foothill", "slopes", "curve", "up", + "to", "the", "strong", "and", "rocky", "gabilan", "mountains", "but", "on", "the", "valley", "side", "the", + "water", "is", "lined", "with", "trees", "willows", "fresh", "and", "green", "with", "every", "spring", "carrying", "in", "their", "lower", "leaf", + "junctures", "the", "debris", "of", "the", "winter", "s", "flooding", "and", "sycamores", "with", "mottled", "white", "recumbent", "limbs", + "and", "branches", "that", "arch", "over", "the", "pool" ] + } + }, + { + "sort": [ 13838435755002691584 ], + "docvalue_fields": { + "body": [ "he", "was", "an", "old", "man", "who", "fished", "alone", "in", "a", "skiff", "inthe", "gulf", "stream", + "and", "he", "had", "gone", "eighty", "four", "days", "now", "without", "taking", "a", "fish" ], + "rating": [ 3.5 ], + "id": [ 1 ] + } + } + ] + } +} +""") + def test_and_query_numeric_fields(self, ram_index_numeric_fields): index = ram_index_numeric_fields searcher = index.searcher()