From adc7b08b75371c0d6de5944cd47bf922eb88766d Mon Sep 17 00:00:00 2001 From: alex-au-922 <79151747+alex-au-922@users.noreply.github.com> Date: Mon, 10 Jun 2024 07:00:39 +0800 Subject: [PATCH] Expose Range Query (Beta) (#281) Co-authored-by: Caleb Hattingh --- src/document.rs | 21 ++++- src/lib.rs | 35 ++++++++- src/query.rs | 83 +++++++++++++++++++- src/schema.rs | 33 ++++++++ tantivy/tantivy.pyi | 34 +++++++- tests/conftest.py | 81 +++++++++++++++++++ tests/tantivy_test.py | 179 +++++++++++++++++++++++++++++++++++++++++- 7 files changed, 457 insertions(+), 9 deletions(-) diff --git a/src/document.rs b/src/document.rs index ac391ab..8e2ba6c 100644 --- a/src/document.rs +++ b/src/document.rs @@ -689,7 +689,7 @@ impl Document { /// Add a JSON value to the document. /// /// Args: - /// field_name (str): The field for which we are adding the bytes. + /// field_name (str): The field for which we are adding the JSON. /// value (str | Dict[str, Any]): The JSON object that will be added /// to the document. /// @@ -716,6 +716,25 @@ impl Document { } } + /// Add an IP address value to the document. + /// + /// Args: + /// field_name (str): The field for which we are adding the IP address. + /// value (str): The IP address object that will be added + /// to the document. + /// + /// Raises a ValueError if the IP address is invalid. 
+ fn add_ip_addr(&mut self, field_name: String, value: &str) -> PyResult<()> { + let ip_addr = IpAddr::from_str(value).map_err(to_pyerr)?; + match ip_addr { + IpAddr::V4(addr) => { + self.add_value(field_name, addr.to_ipv6_mapped()) + } + IpAddr::V6(addr) => self.add_value(field_name, addr), + } + Ok(()) + } + /// Returns the number of added fields that have been added to the document #[getter] fn num_fields(&self) -> usize { diff --git a/src/lib.rs b/src/lib.rs index dd1a5af..e583657 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,17 +12,15 @@ mod schemabuilder; mod searcher; mod snippet; -use document::Document; +use document::{extract_value, extract_value_for_type, Document}; use facet::Facet; use index::Index; use query::{Occur, Query}; -use schema::Schema; +use schema::{FieldType, Schema}; use schemabuilder::SchemaBuilder; use searcher::{DocAddress, Order, SearchResult, Searcher}; use snippet::{Snippet, SnippetGenerator}; -use crate::document::extract_value; - /// Python bindings for the search engine library Tantivy. /// /// Tantivy is a full text search engine library written in rust. 
@@ -88,6 +86,7 @@ fn tantivy(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(query_parser_error))?; @@ -183,3 +182,31 @@ pub(crate) fn make_term( Ok(term) } + +pub(crate) fn make_term_for_type( + schema: &tv::schema::Schema, + field_name: &str, + field_type: FieldType, + field_value: &Bound, +) -> PyResult { + let field = get_field(schema, field_name)?; + let value = + extract_value_for_type(field_value, field_type.into(), field_name)?; + let term = match value { + Value::Str(text) => Term::from_field_text(field, &text), + Value::U64(num) => Term::from_field_u64(field, num), + Value::I64(num) => Term::from_field_i64(field, num), + Value::F64(num) => Term::from_field_f64(field, num), + Value::Date(d) => Term::from_field_date(field, d), + Value::Facet(facet) => Term::from_facet(field, &facet), + Value::Bool(b) => Term::from_field_bool(field, b), + Value::IpAddr(i) => Term::from_field_ip_addr(field, i), + _ => { + return Err(exceptions::PyValueError::new_err(format!( + "Can't create a term for Field `{field_name}` with value `{field_value}`." 
+ ))) + } + }; + + Ok(term) +} diff --git a/src/query.rs b/src/query.rs index d38a747..a043936 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,4 +1,8 @@ -use crate::{get_field, make_term, to_pyerr, DocAddress, Schema}; +use crate::{ + get_field, make_term, make_term_for_type, schema::FieldType, to_pyerr, + DocAddress, Schema, +}; +use core::ops::Bound as OpsBound; use pyo3::{ exceptions, prelude::*, @@ -328,4 +332,81 @@ impl Query { inner: Box::new(inner), }) } + + #[staticmethod] + #[pyo3(signature = (schema, field_name, field_type, lower_bound, upper_bound, include_lower = true, include_upper = true))] + pub(crate) fn range_query( + schema: &Schema, + field_name: &str, + field_type: FieldType, + lower_bound: &Bound, + upper_bound: &Bound, + include_lower: bool, + include_upper: bool, + ) -> PyResult { + match field_type { + FieldType::Text => { + return Err(exceptions::PyValueError::new_err( + "Text fields are not supported for range queries.", + )) + } + FieldType::Boolean => { + return Err(exceptions::PyValueError::new_err( + "Boolean fields are not supported for range queries.", + )) + } + FieldType::Facet => { + return Err(exceptions::PyValueError::new_err( + "Facet fields are not supported for range queries.", + )) + } + FieldType::Bytes => { + return Err(exceptions::PyValueError::new_err( + "Bytes fields are not supported for range queries.", + )) + } + FieldType::Json => { + return Err(exceptions::PyValueError::new_err( + "Json fields are not supported for range queries.", + )) + } + _ => {} + } + + let lower_bound_term = make_term_for_type( + &schema.inner, + field_name, + field_type.clone(), + lower_bound, + )?; + let upper_bound_term = make_term_for_type( + &schema.inner, + field_name, + field_type.clone(), + upper_bound, + )?; + + let lower_bound = if include_lower { + OpsBound::Included(lower_bound_term) + } else { + OpsBound::Excluded(lower_bound_term) + }; + + let upper_bound = if include_upper { + OpsBound::Included(upper_bound_term) + } else { 
+ OpsBound::Excluded(upper_bound_term) + }; + + let inner = tv::query::RangeQuery::new_term_bounds( + field_name.to_string(), + field_type.into(), + &lower_bound, + &upper_bound, + ); + + Ok(Query { + inner: Box::new(inner), + }) + } } diff --git a/src/schema.rs b/src/schema.rs index d1b5549..f18f84e 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -3,6 +3,39 @@ use pyo3::{basic::CompareOp, prelude::*, types::PyTuple}; use serde::{Deserialize, Serialize}; use tantivy as tv; +/// Tantivy's Type +#[pyclass(frozen, module = "tantivy.tantivy")] +#[derive(Clone, PartialEq)] +pub(crate) enum FieldType { + Text, + Unsigned, + Integer, + Float, + Boolean, + Date, + Facet, + Bytes, + Json, + IpAddr, +} + +impl From for tv::schema::Type { + fn from(field_type: FieldType) -> tv::schema::Type { + match field_type { + FieldType::Text => tv::schema::Type::Str, + FieldType::Unsigned => tv::schema::Type::U64, + FieldType::Integer => tv::schema::Type::I64, + FieldType::Float => tv::schema::Type::F64, + FieldType::Boolean => tv::schema::Type::Bool, + FieldType::Date => tv::schema::Type::Date, + FieldType::Facet => tv::schema::Type::Facet, + FieldType::Bytes => tv::schema::Type::Bytes, + FieldType::Json => tv::schema::Type::Json, + FieldType::IpAddr => tv::schema::Type::IpAddr, + } + } +} + /// Tantivy schema. /// /// The schema is very strict. 
To build the schema the `SchemaBuilder` class is diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 8dfe9ad..e77180f 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -1,6 +1,6 @@ import datetime from enum import Enum -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, TypeVar, Union class Schema: pass @@ -168,6 +168,9 @@ class Document: def add_json(self, field_name: str, value: Any) -> None: pass + def add_ip_addr(self, field_name: str, value: str) -> None: + pass + @property def num_fields(self) -> int: pass @@ -187,6 +190,20 @@ class Occur(Enum): Should = 2 MustNot = 3 +class FieldType(Enum): + Text = 1 + Unsigned = 2 + Integer = 3 + Float = 4 + Boolean = 5 + Date = 6 + Facet = 7 + Bytes = 8 + Json = 9 + IpAddr = 10 + +_RangeType = TypeVar("_RangeType", bound=int | float | datetime.datetime | bool | str | bytes) + class Query: @staticmethod def term_query( @@ -255,7 +272,20 @@ class Query: @staticmethod def const_score_query(query: Query, score: float) -> Query: pass - + + @staticmethod + def range_query( + schema: Schema, + field_name: str, + field_type: FieldType, + lower_bound: _RangeType, + upper_bound: _RangeType, + include_lower: bool = True, + include_upper: bool = True, + ) -> Query: + pass + + class Order(Enum): Asc = 1 Desc = 2 diff --git a/tests/conftest.py b/tests/conftest.py index 74c7c80..539c75d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from datetime import datetime import pytest from tantivy import SchemaBuilder, Index, Document @@ -22,6 +23,23 @@ def schema_numeric_fields(): .build() ) +def schema_with_date_field(): + return ( + SchemaBuilder() + .add_integer_field("id", stored=True, indexed=True) + .add_float_field("rating", stored=True, indexed=True) + .add_date_field("date", stored=True, indexed=True) + .build() + ) + +def schema_with_ip_addr_field(): + return ( + SchemaBuilder() + .add_integer_field("id", stored=True, indexed=True) + 
.add_float_field("rating", stored=True, indexed=True) + .add_ip_addr_field("ip_addr", stored=True, indexed=True) + .build() + ) def create_index(dir=None): # assume all tests will use the same documents for now @@ -122,6 +140,62 @@ def create_index_with_numeric_fields(dir=None): index.reload() return index +def create_index_with_date_field(dir=None): + index = Index(schema_with_date_field(), dir) + writer = index.writer(15_000_000, 1) + + doc = Document() + doc.add_integer("id", 1) + doc.add_float("rating", 3.5) + doc.add_date("date", datetime(2021, 1, 1)) + + writer.add_document(doc) + doc = Document.from_dict( + { + "id": 2, + "rating": 4.5, + "date": datetime(2021, 1, 2), + }, + ) + writer.add_document(doc) + writer.commit() + writer.wait_merging_threads() + index.reload() + return index + +def create_index_with_ip_addr_field(dir=None): + schema = schema_with_ip_addr_field() + index = Index(schema, dir) + writer = index.writer(15_000_000, 1) + + doc = Document() + doc.add_integer("id", 1) + doc.add_float("rating", 3.5) + doc.add_ip_addr("ip_addr", "10.0.0.1") + writer.add_document(doc) + + doc = Document.from_dict( + { + "id": 2, + "rating": 4.5, + "ip_addr": "127.0.0.1", + }, + schema + ) + writer.add_document(doc) + doc = Document.from_dict( + { + "id": 3, + "rating": 4.5, + "ip_addr": "::1", + }, + schema + ) + writer.add_document(doc) + writer.commit() + writer.wait_merging_threads() + index.reload() + return index def spanish_schema(): return ( @@ -188,6 +262,13 @@ def ram_index(): def ram_index_numeric_fields(): return create_index_with_numeric_fields() +@pytest.fixture(scope="class") +def ram_index_with_date_field(): + return create_index_with_date_field() + +@pytest.fixture(scope="class") +def ram_index_with_ip_addr_field(): + return create_index_with_ip_addr_field() @pytest.fixture(scope="class") def spanish_index(): diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 3ad9a2f..6200ed7 100644 --- a/tests/tantivy_test.py +++ 
b/tests/tantivy_test.py @@ -8,7 +8,7 @@ import pytest import tantivy from conftest import schema, schema_numeric_fields -from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur +from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur, FieldType class TestClass(object): @@ -1260,3 +1260,180 @@ class TestQuery(object): # wrong score type with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"): Query.const_score_query(query, "0.1") + + def test_range_query_numerics(self, ram_index_numeric_fields): + index = ram_index_numeric_fields + + # test integer field including both bounds + query = Query.range_query(index.schema, "id", FieldType.Integer, 1, 2) + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + # test integer field excluding the lower bound + query = Query.range_query(index.schema, "id", FieldType.Integer, 1, 2, include_lower=False) + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["id"][0] == 2 + + # test float field including both bounds + query = Query.range_query(index.schema, "rating", FieldType.Float, 3.5, 4.0) + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["id"][0] == 1 + + # test float field excluding the lower bound + query = Query.range_query(index.schema, "rating", FieldType.Float, 3.5, 4.0, include_lower=False) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + # test float field excluding the upper bound + query = Query.range_query(index.schema, "rating", FieldType.Float, 3.0, 3.5, include_upper=False) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + # test if the lower bound is greater than the upper bound + 
query = Query.range_query(index.schema, "rating", FieldType.Float, 4.0, 3.5) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + def test_range_query_dates(self, ram_index_with_date_field): + index = ram_index_with_date_field + + # test date field including both bounds + query = Query.range_query( + index.schema, + "date", + FieldType.Date, + datetime.datetime(2020, 1, 1), + datetime.datetime(2022, 1, 1) + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + # test date field excluding the lower bound + query = Query.range_query( + index.schema, "date", + FieldType.Date, + datetime.datetime(2020, 1, 1), + datetime.datetime(2021, 1, 1), + include_lower=False + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + + # test date field excluding the upper bound + query = Query.range_query( + index.schema, + "date", + FieldType.Date, + datetime.datetime(2020, 1, 1), + datetime.datetime(2021, 1, 1), + include_upper=False + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + def test_range_query_ip_addrs(self, ram_index_with_ip_addr_field): + index = ram_index_with_ip_addr_field + + # test ip address field including both bounds + query = Query.range_query( + index.schema, + "ip_addr", + FieldType.IpAddr, + "10.0.0.0", + "10.0.255.255" + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + + query = Query.range_query( + index.schema, + "ip_addr", + FieldType.IpAddr, + "0.0.0.0", + "255.255.255.255" + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + + # test ip address field excluding the lower bound + query = Query.range_query( + index.schema, + "ip_addr", + FieldType.IpAddr, + "10.0.0.1", + "10.0.0.255", + include_lower=False + ) + + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + # test ip address field excluding the upper bound + query = Query.range_query( + 
index.schema, + "ip_addr", + FieldType.IpAddr, + "127.0.0.0", + "127.0.0.1", + include_upper=False + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 0 + + # test loopback address + query = Query.range_query( + index.schema, + "ip_addr", + FieldType.IpAddr, + "::1", + "::1" + ) + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + + def test_range_query_invalid_types( + self, + ram_index, + ram_index_numeric_fields, + ram_index_with_date_field, + ram_index_with_ip_addr_field + ): + index = ram_index + query = Query.range_query(index.schema, "title", FieldType.Integer, 1, 2) + with pytest.raises(ValueError, match="Create a range query of the type I64, when the field given was of type Str"): + index.searcher().search(query, 10) + + index = ram_index_numeric_fields + query = Query.range_query(index.schema, "id", FieldType.Float, 1.0, 2.0) + with pytest.raises(ValueError, match="Create a range query of the type F64, when the field given was of type I64"): + index.searcher().search(query, 10) + + index = ram_index_with_date_field + query = Query.range_query(index.schema, "date", FieldType.Integer, 1, 2) + with pytest.raises(ValueError, match="Create a range query of the type I64, when the field given was of type Date"): + index.searcher().search(query, 10) + + index = ram_index_with_ip_addr_field + query = Query.range_query(index.schema, "ip_addr", FieldType.Integer, 1, 2) + with pytest.raises(ValueError, match="Create a range query of the type I64, when the field given was of type IpAddr"): + index.searcher().search(query, 10) + + def test_range_query_unsupported_types(self, ram_index): + index = ram_index + with pytest.raises(ValueError, match="Text fields are not supported for range queries."): + Query.range_query(index.schema, "title", FieldType.Text, 1, 2) + + with pytest.raises(ValueError, match="Json fields are not supported for range queries."): + Query.range_query(index.schema, "title", FieldType.Json, 1, 2) 
+ + with pytest.raises(ValueError, match="Bytes fields are not supported for range queries."): + Query.range_query(index.schema, "title", FieldType.Bytes, 1, 2) + + with pytest.raises(ValueError, match="Facet fields are not supported for range queries."): + Query.range_query(index.schema, "title", FieldType.Facet, 1, 2)