feat: Aggregations API (#288)
parent
e3de7b1aab
commit
8ece24161b
|
@ -1,9 +1,11 @@
|
||||||
#![allow(clippy::new_ret_no_self)]
|
#![allow(clippy::new_ret_no_self)]
|
||||||
|
|
||||||
use crate::{document::Document, query::Query, to_pyerr};
|
use crate::{document::Document, query::Query, to_pyerr};
|
||||||
|
use pyo3::types::PyDict;
|
||||||
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
|
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tantivy as tv;
|
use tantivy as tv;
|
||||||
|
use tantivy::aggregation::AggregationCollector;
|
||||||
use tantivy::collector::{Count, MultiCollector, TopDocs};
|
use tantivy::collector::{Count, MultiCollector, TopDocs};
|
||||||
use tantivy::TantivyDocument;
|
use tantivy::TantivyDocument;
|
||||||
// Bring the trait into scope. This is required for the `to_named_doc` method.
|
// Bring the trait into scope. This is required for the `to_named_doc` method.
|
||||||
|
@ -233,6 +235,35 @@ impl Searcher {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyo3(signature = (query, agg))]
|
||||||
|
fn aggregate(
|
||||||
|
&self,
|
||||||
|
py: Python,
|
||||||
|
query: &Query,
|
||||||
|
agg: Py<PyDict>,
|
||||||
|
) -> PyResult<Py<PyDict>> {
|
||||||
|
let py_json = py.import_bound("json")?;
|
||||||
|
let agg_query_str = py_json.call_method1("dumps", (agg,))?.to_string();
|
||||||
|
|
||||||
|
let agg_str = py.allow_threads(move || {
|
||||||
|
let agg_collector = AggregationCollector::from_aggs(
|
||||||
|
serde_json::from_str(&agg_query_str).map_err(to_pyerr)?,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let agg_res = self
|
||||||
|
.inner
|
||||||
|
.search(query.get(), &agg_collector)
|
||||||
|
.map_err(to_pyerr)?;
|
||||||
|
|
||||||
|
serde_json::to_string(&agg_res).map_err(to_pyerr)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let agg_dict = py_json.call_method1("loads", (agg_str,))?;
|
||||||
|
let agg_dict = agg_dict.downcast::<PyDict>()?;
|
||||||
|
|
||||||
|
Ok(agg_dict.clone().unbind())
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the overall number of documents in the index.
|
/// Returns the overall number of documents in the index.
|
||||||
#[getter]
|
#[getter]
|
||||||
fn num_docs(&self) -> u64 {
|
fn num_docs(&self) -> u64 {
|
||||||
|
|
|
@ -289,6 +289,13 @@ class Searcher:
|
||||||
) -> SearchResult:
|
) -> SearchResult:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def aggregate(
    self,
    query: Query,
    agg: dict,
) -> dict:
    """Run an aggregation over the documents matched by ``query``.

    The parameter names mirror the native binding, which is declared with
    ``signature = (query, agg)``; keyword calls must use these names.

    Args:
        query: The query selecting the documents to aggregate over.
        agg: The aggregation request as a JSON-serializable dict.

    Returns:
        The aggregation result as a dict.
    """
    pass
|
||||||
|
|
||||||
@property
def num_docs(self) -> int:
    """Return the overall number of documents in the index."""
    pass
|
||||||
|
|
|
@ -15,10 +15,10 @@ def schema():
|
||||||
def schema_numeric_fields():
    """Build the schema with ``id``, ``rating``, ``is_good`` and ``body`` fields.

    ``id``, ``rating`` and ``body`` are declared with ``fast=True``;
    ``is_good`` is only stored and indexed.
    """
    builder = SchemaBuilder()
    builder = builder.add_integer_field("id", stored=True, indexed=True, fast=True)
    builder = builder.add_float_field("rating", stored=True, indexed=True, fast=True)
    builder = builder.add_boolean_field("is_good", stored=True, indexed=True)
    builder = builder.add_text_field("body", stored=True, fast=True)
    return builder.build()
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,60 @@ class TestClass(object):
|
||||||
|
|
||||||
assert len(result.hits) == 1
|
assert len(result.hits) == 1
|
||||||
|
|
||||||
|
def test_and_aggregate(self, ram_index_numeric_fields):
    """Run a ``top_hits`` aggregation over all documents and check the full result."""
    match_everything = Query.all_query()
    # Ask for the two highest-rated documents and three doc-value fields each.
    top_hits_request = {
        "top_hits_req": {
            "top_hits": {
                "size": 2,
                "sort": [{"rating": "desc"}],
                "from": 0,
                "docvalue_fields": ["rating", "id", "body"],
            }
        }
    }
    searcher = ram_index_numeric_fields.searcher()
    response = searcher.aggregate(match_everything, top_hits_request)

    # Structural checks first, so a shape mismatch fails with a short message.
    assert isinstance(response, dict)
    assert "top_hits_req" in response
    hits = response["top_hits_req"]["hits"]
    assert len(hits) == 2
    assert all(len(hit["docvalue_fields"]) == 3 for hit in hits)

    # NOTE(review): the large "sort" integers are presumably the engine's
    # internal sortable encoding of the float rating -- confirm before editing.
    expected = json.loads("""
{
"top_hits_req": {
    "hits": [
      {
        "sort": [ 13840124604862955520 ],
        "docvalue_fields": {
          "id": [ 2 ],
          "rating": [ 4.5 ],
          "body": [ "a", "few", "miles", "south", "of", "soledad", "the", "salinas", "river", "drops", "in", "close", "to", "the", "hillside",
                    "bank", "and", "runs", "deep", "and", "green", "the", "water", "is", "warm", "too", "for", "it", "has", "slipped", "twinkling",
                    "over", "the", "yellow", "sands", "in", "the", "sunlight", "before", "reaching", "the", "narrow", "pool",
                    "on", "one", "side", "of", "the", "river", "the", "golden", "foothill", "slopes", "curve", "up",
                    "to", "the", "strong", "and", "rocky", "gabilan", "mountains", "but", "on", "the", "valley", "side", "the",
                    "water", "is", "lined", "with", "trees", "willows", "fresh", "and", "green", "with", "every", "spring", "carrying", "in", "their", "lower", "leaf",
                    "junctures", "the", "debris", "of", "the", "winter", "s", "flooding", "and", "sycamores", "with", "mottled", "white", "recumbent", "limbs",
                    "and", "branches", "that", "arch", "over", "the", "pool" ]
        }
      },
      {
        "sort": [ 13838435755002691584 ],
        "docvalue_fields": {
          "body": [ "he", "was", "an", "old", "man", "who", "fished", "alone", "in", "a", "skiff", "inthe", "gulf", "stream",
                    "and", "he", "had", "gone", "eighty", "four", "days", "now", "without", "taking", "a", "fish" ],
          "rating": [ 3.5 ],
          "id": [ 1 ]
        }
      }
    ]
  }
}
""")
    assert response == expected
|
||||||
|
|
||||||
def test_and_query_numeric_fields(self, ram_index_numeric_fields):
|
def test_and_query_numeric_fields(self, ram_index_numeric_fields):
|
||||||
index = ram_index_numeric_fields
|
index = ram_index_numeric_fields
|
||||||
searcher = index.searcher()
|
searcher = index.searcher()
|
||||||
|
|
Loading…
Reference in New Issue