Expose Range Query (Beta) (#281)

Co-authored-by: Caleb Hattingh <caleb.hattingh@gmail.com>
master
alex-au-922 2024-06-10 07:00:39 +08:00 committed by GitHub
parent 8ece24161b
commit adc7b08b75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 457 additions and 9 deletions

View File

@ -689,7 +689,7 @@ impl Document {
/// Add a JSON value to the document.
///
/// Args:
/// field_name (str): The field for which we are adding the bytes.
/// field_name (str): The field for which we are adding the JSON.
/// value (str | Dict[str, Any]): The JSON object that will be added
/// to the document.
///
@ -716,6 +716,25 @@ impl Document {
}
}
/// Add an IP address value to the document.
///
/// Args:
/// field_name (str): The field for which we are adding the IP address.
/// value (str): The IP address object that will be added
/// to the document.
///
/// Raises a ValueError if the IP address is invalid.
fn add_ip_addr(&mut self, field_name: String, value: &str) -> PyResult<()> {
let ip_addr = IpAddr::from_str(value).map_err(to_pyerr)?;
match ip_addr {
IpAddr::V4(addr) => {
self.add_value(field_name, addr.to_ipv6_mapped())
}
IpAddr::V6(addr) => self.add_value(field_name, addr),
}
Ok(())
}
/// Returns the number of added fields that have been added to the document
#[getter]
fn num_fields(&self) -> usize {

View File

@ -12,17 +12,15 @@ mod schemabuilder;
mod searcher;
mod snippet;
use document::Document;
use document::{extract_value, extract_value_for_type, Document};
use facet::Facet;
use index::Index;
use query::{Occur, Query};
use schema::Schema;
use schema::{FieldType, Schema};
use schemabuilder::SchemaBuilder;
use searcher::{DocAddress, Order, SearchResult, Searcher};
use snippet::{Snippet, SnippetGenerator};
use crate::document::extract_value;
/// Python bindings for the search engine library Tantivy.
///
/// Tantivy is a full text search engine library written in rust.
@ -88,6 +86,7 @@ fn tantivy(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<Snippet>()?;
m.add_class::<SnippetGenerator>()?;
m.add_class::<Occur>()?;
m.add_class::<FieldType>()?;
m.add_wrapped(wrap_pymodule!(query_parser_error))?;
@ -183,3 +182,31 @@ pub(crate) fn make_term(
Ok(term)
}
/// Build a tantivy `Term` for `field_name` from a Python value, interpreting
/// the value as the explicitly requested `field_type`.
///
/// The field is looked up in `schema` first, then `field_value` is extracted
/// via `extract_value_for_type`. Returns a `ValueError` when the extracted
/// value is of a kind for which no term can be built.
pub(crate) fn make_term_for_type(
    schema: &tv::schema::Schema,
    field_name: &str,
    field_type: FieldType,
    field_value: &Bound<PyAny>,
) -> PyResult<tv::Term> {
    let field = get_field(schema, field_name)?;
    match extract_value_for_type(field_value, field_type.into(), field_name)? {
        Value::Str(text) => Ok(Term::from_field_text(field, &text)),
        Value::U64(num) => Ok(Term::from_field_u64(field, num)),
        Value::I64(num) => Ok(Term::from_field_i64(field, num)),
        Value::F64(num) => Ok(Term::from_field_f64(field, num)),
        Value::Date(date) => Ok(Term::from_field_date(field, date)),
        Value::Facet(facet) => Ok(Term::from_facet(field, &facet)),
        Value::Bool(flag) => Ok(Term::from_field_bool(field, flag)),
        Value::IpAddr(addr) => Ok(Term::from_field_ip_addr(field, addr)),
        // Any remaining value kind has no term constructor here.
        _ => Err(exceptions::PyValueError::new_err(format!(
            "Can't create a term for Field `{field_name}` with value `{field_value}`."
        ))),
    }
}

View File

@ -1,4 +1,8 @@
use crate::{get_field, make_term, to_pyerr, DocAddress, Schema};
use crate::{
get_field, make_term, make_term_for_type, schema::FieldType, to_pyerr,
DocAddress, Schema,
};
use core::ops::Bound as OpsBound;
use pyo3::{
exceptions,
prelude::*,
@ -328,4 +332,81 @@ impl Query {
inner: Box::new(inner),
})
}
/// Construct a range query over `field_name` between two bounds.
///
/// Args:
///     schema (Schema): Schema used to resolve `field_name`.
///     field_name (str): Name of the field to query.
///     field_type (FieldType): Type used to interpret both bounds.
///     lower_bound: Lower bound of the range.
///     upper_bound: Upper bound of the range.
///     include_lower (bool): Whether the lower bound is inclusive.
///         Defaults to True.
///     include_upper (bool): Whether the upper bound is inclusive.
///         Defaults to True.
///
/// Raises a ValueError if `field_type` is Text, Boolean, Facet, Bytes or
/// Json, or if a bound cannot be converted into a term of `field_type`.
#[staticmethod]
#[pyo3(signature = (schema, field_name, field_type, lower_bound, upper_bound, include_lower = true, include_upper = true))]
pub(crate) fn range_query(
    schema: &Schema,
    field_name: &str,
    field_type: FieldType,
    lower_bound: &Bound<PyAny>,
    upper_bound: &Bound<PyAny>,
    include_lower: bool,
    include_upper: bool,
) -> PyResult<Query> {
    // Exhaustive on purpose: a new `FieldType` variant must decide here
    // whether range queries are allowed for it.
    let rejection = match field_type {
        FieldType::Text => {
            Some("Text fields are not supported for range queries.")
        }
        FieldType::Boolean => {
            Some("Boolean fields are not supported for range queries.")
        }
        FieldType::Facet => {
            Some("Facet fields are not supported for range queries.")
        }
        FieldType::Bytes => {
            Some("Bytes fields are not supported for range queries.")
        }
        FieldType::Json => {
            Some("Json fields are not supported for range queries.")
        }
        FieldType::Unsigned
        | FieldType::Integer
        | FieldType::Float
        | FieldType::Date
        | FieldType::IpAddr => None,
    };
    if let Some(message) = rejection {
        return Err(exceptions::PyValueError::new_err(message));
    }

    // Convert both bounds before wrapping them, lower first.
    let lower_term = make_term_for_type(
        &schema.inner,
        field_name,
        field_type.clone(),
        lower_bound,
    )?;
    let upper_term = make_term_for_type(
        &schema.inner,
        field_name,
        field_type.clone(),
        upper_bound,
    )?;

    let to_bound = |term: tv::Term, inclusive: bool| {
        if inclusive {
            OpsBound::Included(term)
        } else {
            OpsBound::Excluded(term)
        }
    };

    let range = tv::query::RangeQuery::new_term_bounds(
        field_name.to_string(),
        field_type.into(),
        &to_bound(lower_term, include_lower),
        &to_bound(upper_term, include_upper),
    );
    Ok(Query {
        inner: Box::new(range),
    })
}
}

View File

@ -3,6 +3,39 @@ use pyo3::{basic::CompareOp, prelude::*, types::PyTuple};
use serde::{Deserialize, Serialize};
use tantivy as tv;
/// Field types exposed to Python for building typed terms and queries.
///
/// Each variant converts into a tantivy schema `Type` through the
/// `From<FieldType>` implementation for `tv::schema::Type`.
#[pyclass(frozen, module = "tantivy.tantivy")]
#[derive(Clone, PartialEq)]
pub(crate) enum FieldType {
    /// Text (string) field.
    Text,
    /// Unsigned 64-bit integer field.
    Unsigned,
    /// Signed 64-bit integer field.
    Integer,
    /// 64-bit floating point field.
    Float,
    /// Boolean field.
    Boolean,
    /// Date field.
    Date,
    /// Facet field.
    Facet,
    /// Bytes field.
    Bytes,
    /// JSON field.
    Json,
    /// IP address field.
    IpAddr,
}
/// Converts the Python-facing `FieldType` into the corresponding tantivy
/// schema `Type`.
impl From<FieldType> for tv::schema::Type {
    fn from(field_type: FieldType) -> tv::schema::Type {
        match field_type {
            FieldType::Text => tv::schema::Type::Str,
            FieldType::Unsigned => tv::schema::Type::U64,
            FieldType::Integer => tv::schema::Type::I64,
            FieldType::Float => tv::schema::Type::F64,
            // Boolean fields are `Type::Bool`. Mapping them to `Type::Str`
            // made callers interpret boolean values as text.
            FieldType::Boolean => tv::schema::Type::Bool,
            FieldType::Date => tv::schema::Type::Date,
            FieldType::Facet => tv::schema::Type::Facet,
            FieldType::Bytes => tv::schema::Type::Bytes,
            FieldType::Json => tv::schema::Type::Json,
            FieldType::IpAddr => tv::schema::Type::IpAddr,
        }
    }
}
/// Tantivy schema.
///
/// The schema is very strict. To build the schema the `SchemaBuilder` class is

View File

@ -1,6 +1,6 @@
import datetime
from enum import Enum
from typing import Any, Optional, Sequence, Union
from typing import Any, Optional, Sequence, TypeVar, Union
class Schema:
pass
@ -168,6 +168,9 @@ class Document:
def add_json(self, field_name: str, value: Any) -> None:
pass
def add_ip_addr(self, field_name: str, value: str) -> None:
    """Add an IP address value to the document.

    The second parameter is named ``value`` to match the native
    implementation, so keyword calls that type-check also work at runtime.

    Args:
        field_name: The field for which the IP address is added.
        value: The IP address, as a string, to add to the document.

    Raises:
        ValueError: If ``value`` is not a valid IP address.
    """
    pass
@property
def num_fields(self) -> int:
pass
@ -187,6 +190,20 @@ class Occur(Enum):
Should = 2
MustNot = 3
class FieldType(Enum):
    """Field types that can be named when building typed queries.

    Passed as the ``field_type`` argument of ``Query.range_query``.
    """

    Text = 1
    Unsigned = 2
    Integer = 3
    Float = 4
    Boolean = 5
    Date = 6
    Facet = 7
    Bytes = 8
    Json = 9
    IpAddr = 10
# Type variable shared by the `lower_bound` and `upper_bound` parameters of
# `Query.range_query`.
_RangeType = TypeVar("_RangeType", bound=int | float | datetime.datetime | bool | str | bytes)
class Query:
@staticmethod
def term_query(
@ -256,6 +273,19 @@ class Query:
def const_score_query(query: Query, score: float) -> Query:
pass
@staticmethod
def range_query(
    schema: Schema,
    field_name: str,
    field_type: FieldType,
    lower_bound: _RangeType,
    upper_bound: _RangeType,
    include_lower: bool = True,
    include_upper: bool = True,
) -> Query:
    """Build a range query over ``field_name`` between two bounds.

    Args:
        schema: Schema used to resolve ``field_name``.
        field_name: Name of the field to query.
        field_type: Type used to interpret both bounds.
        lower_bound: Lower bound of the range.
        upper_bound: Upper bound of the range.
        include_lower: Whether the lower bound is inclusive.
        include_upper: Whether the upper bound is inclusive.

    Raises:
        ValueError: If ``field_type`` is Text, Boolean, Facet, Bytes or
            Json, which are not supported for range queries.
    """
    pass
class Order(Enum):
Asc = 1
Desc = 2

View File

@ -1,3 +1,4 @@
from datetime import datetime
import pytest
from tantivy import SchemaBuilder, Index, Document
@ -22,6 +23,23 @@ def schema_numeric_fields():
.build()
)
def schema_with_date_field():
    """Build a schema with an integer ``id``, a float ``rating`` and a ``date`` field."""
    builder = SchemaBuilder()
    builder = builder.add_integer_field("id", stored=True, indexed=True)
    builder = builder.add_float_field("rating", stored=True, indexed=True)
    builder = builder.add_date_field("date", stored=True, indexed=True)
    return builder.build()
def schema_with_ip_addr_field():
    """Build a schema with an integer ``id``, a float ``rating`` and an ``ip_addr`` field."""
    builder = SchemaBuilder()
    builder = builder.add_integer_field("id", stored=True, indexed=True)
    builder = builder.add_float_field("rating", stored=True, indexed=True)
    builder = builder.add_ip_addr_field("ip_addr", stored=True, indexed=True)
    return builder.build()
def create_index(dir=None):
# assume all tests will use the same documents for now
@ -122,6 +140,62 @@ def create_index_with_numeric_fields(dir=None):
index.reload()
return index
def create_index_with_date_field(dir=None):
    """Create and populate an index using ``schema_with_date_field()``.

    Two documents are written: one built field by field and one built with
    ``Document.from_dict``. The index is committed and reloaded before it is
    returned.
    """
    index = Index(schema_with_date_field(), dir)
    writer = index.writer(15_000_000, 1)

    # First document: assembled through the typed ``add_*`` methods.
    first = Document()
    first.add_integer("id", 1)
    first.add_float("rating", 3.5)
    first.add_date("date", datetime(2021, 1, 1))
    writer.add_document(first)

    # Second document: assembled from a plain dict.
    second_fields = {
        "id": 2,
        "rating": 4.5,
        "date": datetime(2021, 1, 2),
    }
    writer.add_document(Document.from_dict(second_fields))

    writer.commit()
    writer.wait_merging_threads()
    index.reload()
    return index
def create_index_with_ip_addr_field(dir=None):
    """Create and populate an index using ``schema_with_ip_addr_field()``.

    Three documents are written: one built with ``add_ip_addr`` and two built
    with ``Document.from_dict`` (an IPv4 and an IPv6 address). The index is
    committed and reloaded before it is returned.
    """
    schema = schema_with_ip_addr_field()
    index = Index(schema, dir)
    writer = index.writer(15_000_000, 1)

    # First document: assembled through the typed ``add_*`` methods.
    first = Document()
    first.add_integer("id", 1)
    first.add_float("rating", 3.5)
    first.add_ip_addr("ip_addr", "10.0.0.1")
    writer.add_document(first)

    # Remaining documents differ only in their address.
    for address in ("127.0.0.1", "::1"):
        fields = {
            "id": 2,
            "rating": 4.5,
            "ip_addr": address,
        }
        writer.add_document(Document.from_dict(fields, schema))

    writer.commit()
    writer.wait_merging_threads()
    index.reload()
    return index
def spanish_schema():
return (
@ -188,6 +262,13 @@ def ram_index():
def ram_index_numeric_fields():
return create_index_with_numeric_fields()
@pytest.fixture(scope="class")
def ram_index_with_date_field():
    """Class-scoped fixture: the index from ``create_index_with_date_field()`` with its default ``dir``."""
    index = create_index_with_date_field()
    return index
@pytest.fixture(scope="class")
def ram_index_with_ip_addr_field():
    """Class-scoped fixture: the index from ``create_index_with_ip_addr_field()`` with its default ``dir``."""
    index = create_index_with_ip_addr_field()
    return index
@pytest.fixture(scope="class")
def spanish_index():

View File

@ -8,7 +8,7 @@ import pytest
import tantivy
from conftest import schema, schema_numeric_fields
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query, Occur, FieldType
class TestClass(object):
@ -1260,3 +1260,180 @@ class TestQuery(object):
# wrong score type
with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"):
Query.const_score_query(query, "0.1")
def test_range_query_numerics(self, ram_index_numeric_fields):
    """Range queries over the integer ``id`` and float ``rating`` fields."""
    index = ram_index_numeric_fields

    def hits_for(field, field_type, low, high, **bound_flags):
        # Build the query and return the hits of a top-10 search.
        query = Query.range_query(
            index.schema, field, field_type, low, high, **bound_flags
        )
        return index.searcher().search(query, 10).hits

    def stored_id(hit):
        # Resolve a hit to the first stored ``id`` value of its document.
        _, doc_address = hit
        return index.searcher().doc(doc_address)["id"][0]

    # integer field, both bounds included
    assert len(hits_for("id", FieldType.Integer, 1, 2)) == 2

    # integer field, lower bound excluded
    hits = hits_for("id", FieldType.Integer, 1, 2, include_lower=False)
    assert len(hits) == 1
    assert stored_id(hits[0]) == 2

    # float field, both bounds included
    hits = hits_for("rating", FieldType.Float, 3.5, 4.0)
    assert len(hits) == 1
    assert stored_id(hits[0]) == 1

    # float field, lower bound excluded
    hits = hits_for("rating", FieldType.Float, 3.5, 4.0, include_lower=False)
    assert len(hits) == 0

    # float field, upper bound excluded
    hits = hits_for("rating", FieldType.Float, 3.0, 3.5, include_upper=False)
    assert len(hits) == 0

    # lower bound greater than the upper bound
    hits = hits_for("rating", FieldType.Float, 4.0, 3.5)
    assert len(hits) == 0
def test_range_query_dates(self, ram_index_with_date_field):
    """Range queries over the ``date`` field with inclusive and exclusive bounds."""
    index = ram_index_with_date_field

    # (lower bound, upper bound, bound flags, expected number of hits)
    cases = [
        # both bounds included
        (datetime.datetime(2020, 1, 1), datetime.datetime(2022, 1, 1), {}, 2),
        # lower bound excluded
        (
            datetime.datetime(2020, 1, 1),
            datetime.datetime(2021, 1, 1),
            {"include_lower": False},
            1,
        ),
        # upper bound excluded
        (
            datetime.datetime(2020, 1, 1),
            datetime.datetime(2021, 1, 1),
            {"include_upper": False},
            0,
        ),
    ]

    for lower, upper, bound_flags, expected_hits in cases:
        query = Query.range_query(
            index.schema, "date", FieldType.Date, lower, upper, **bound_flags
        )
        result = index.searcher().search(query, 10)
        assert len(result.hits) == expected_hits
def test_range_query_ip_addrs(self, ram_index_with_ip_addr_field):
    """Range queries over the ``ip_addr`` field for IPv4 and IPv6 bounds."""
    index = ram_index_with_ip_addr_field

    # (lower bound, upper bound, bound flags, expected number of hits)
    cases = [
        # both bounds included
        ("10.0.0.0", "10.0.255.255", {}, 1),
        ("0.0.0.0", "255.255.255.255", {}, 2),
        # lower bound excluded
        ("10.0.0.1", "10.0.0.255", {"include_lower": False}, 0),
        # upper bound excluded
        ("127.0.0.0", "127.0.0.1", {"include_upper": False}, 0),
        # loopback address
        ("::1", "::1", {}, 1),
    ]

    for lower, upper, bound_flags, expected_hits in cases:
        query = Query.range_query(
            index.schema,
            "ip_addr",
            FieldType.IpAddr,
            lower,
            upper,
            **bound_flags
        )
        result = index.searcher().search(query, 10)
        assert len(result.hits) == expected_hits
def test_range_query_invalid_types(
    self,
    ram_index,
    ram_index_numeric_fields,
    ram_index_with_date_field,
    ram_index_with_ip_addr_field
):
    """A range query whose type differs from the field's type fails when searched."""
    # (index, field, query type, lower bound, upper bound, expected error text)
    cases = [
        (
            ram_index, "title", FieldType.Integer, 1, 2,
            "Create a range query of the type I64, when the field given was of type Str",
        ),
        (
            ram_index_numeric_fields, "id", FieldType.Float, 1.0, 2.0,
            "Create a range query of the type F64, when the field given was of type I64",
        ),
        (
            ram_index_with_date_field, "date", FieldType.Integer, 1, 2,
            "Create a range query of the type I64, when the field given was of type Date",
        ),
        (
            ram_index_with_ip_addr_field, "ip_addr", FieldType.Integer, 1, 2,
            "Create a range query of the type I64, when the field given was of type IpAddr",
        ),
    ]

    for index, field, query_type, lower, upper, message in cases:
        # Building the query succeeds; the mismatch only surfaces on search.
        query = Query.range_query(index.schema, field, query_type, lower, upper)
        with pytest.raises(ValueError, match=message):
            index.searcher().search(query, 10)
def test_range_query_unsupported_types(self, ram_index):
    """Every field type rejected by ``Query.range_query`` raises ``ValueError``.

    Covers all five rejected types, including ``Boolean``, which was
    previously left untested.
    """
    index = ram_index

    with pytest.raises(ValueError, match="Text fields are not supported for range queries."):
        Query.range_query(index.schema, "title", FieldType.Text, 1, 2)

    with pytest.raises(ValueError, match="Boolean fields are not supported for range queries."):
        Query.range_query(index.schema, "title", FieldType.Boolean, 1, 2)

    with pytest.raises(ValueError, match="Json fields are not supported for range queries."):
        Query.range_query(index.schema, "title", FieldType.Json, 1, 2)

    with pytest.raises(ValueError, match="Bytes fields are not supported for range queries."):
        Query.range_query(index.schema, "title", FieldType.Bytes, 1, 2)

    with pytest.raises(ValueError, match="Facet fields are not supported for range queries."):
        Query.range_query(index.schema, "title", FieldType.Facet, 1, 2)