Better support bytes, IPs, and JSON (#152)
parent
4ac17da8f6
commit
eeaad34a98
|
@ -6,9 +6,10 @@ use pyo3::{
|
||||||
basic::CompareOp,
|
basic::CompareOp,
|
||||||
prelude::*,
|
prelude::*,
|
||||||
types::{
|
types::{
|
||||||
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess,
|
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyInt, PyList,
|
||||||
PyTuple,
|
PyTimeAccess, PyTuple,
|
||||||
},
|
},
|
||||||
|
Python,
|
||||||
};
|
};
|
||||||
|
|
||||||
use chrono::{offset::TimeZone, NaiveDateTime, Utc};
|
use chrono::{offset::TimeZone, NaiveDateTime, Utc};
|
||||||
|
@ -23,7 +24,8 @@ use serde_json::Value as JsonValue;
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, HashMap},
|
collections::{BTreeMap, HashMap},
|
||||||
fmt,
|
fmt,
|
||||||
net::Ipv6Addr,
|
net::{IpAddr, Ipv6Addr},
|
||||||
|
str::FromStr,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
|
pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
|
||||||
|
@ -50,6 +52,11 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
|
||||||
if let Ok(b) = any.extract::<Vec<u8>>() {
|
if let Ok(b) = any.extract::<Vec<u8>>() {
|
||||||
return Ok(Value::Bytes(b));
|
return Ok(Value::Bytes(b));
|
||||||
}
|
}
|
||||||
|
if let Ok(dict) = any.downcast::<PyDict>() {
|
||||||
|
if let Ok(json) = pythonize::depythonize(dict) {
|
||||||
|
return Ok(Value::JsonObject(json));
|
||||||
|
}
|
||||||
|
}
|
||||||
Err(to_pyerr(format!("Value unsupported {any:?}")))
|
Err(to_pyerr(format!("Value unsupported {any:?}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +112,37 @@ pub(crate) fn extract_value_for_type(
|
||||||
.map_err(to_pyerr_for_type("Facet", field_name, any))?
|
.map_err(to_pyerr_for_type("Facet", field_name, any))?
|
||||||
.inner,
|
.inner,
|
||||||
),
|
),
|
||||||
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
|
tv::schema::Type::Bytes => Value::Bytes(
|
||||||
|
any.extract::<Vec<u8>>()
|
||||||
|
.map_err(to_pyerr_for_type("Bytes", field_name, any))?,
|
||||||
|
),
|
||||||
|
tv::schema::Type::Json => {
|
||||||
|
if let Ok(json_str) = any.extract::<&str>() {
|
||||||
|
return serde_json::from_str(json_str)
|
||||||
|
.map(Value::JsonObject)
|
||||||
|
.map_err(to_pyerr_for_type("Json", field_name, any));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value::JsonObject(
|
||||||
|
any.downcast::<PyDict>()
|
||||||
|
.map(|dict| pythonize::depythonize(&dict))
|
||||||
|
.map_err(to_pyerr_for_type("Json", field_name, any))?
|
||||||
|
.map_err(to_pyerr_for_type("Json", field_name, any))?,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
tv::schema::Type::IpAddr => {
|
||||||
|
let val = any
|
||||||
|
.extract::<&str>()
|
||||||
|
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?;
|
||||||
|
|
||||||
|
IpAddr::from_str(val)
|
||||||
|
.map(|addr| match addr {
|
||||||
|
IpAddr::V4(addr) => addr.to_ipv6_mapped(),
|
||||||
|
IpAddr::V6(addr) => addr,
|
||||||
|
})
|
||||||
|
.map(Value::IpAddr)
|
||||||
|
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(value)
|
Ok(value)
|
||||||
|
@ -126,6 +163,20 @@ fn extract_value_single_or_list_for_type(
|
||||||
) -> PyResult<Vec<Value>> {
|
) -> PyResult<Vec<Value>> {
|
||||||
// Check if a numeric fast field supports multivalues.
|
// Check if a numeric fast field supports multivalues.
|
||||||
if let Ok(values) = any.downcast::<PyList>() {
|
if let Ok(values) = any.downcast::<PyList>() {
|
||||||
|
// Process an array of integers as a single entry if it is a bytes field.
|
||||||
|
if field_type.value_type() == tv::schema::Type::Bytes
|
||||||
|
&& values
|
||||||
|
.get_item(0)
|
||||||
|
.map(|v| v.is_instance_of::<PyInt>())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
return Ok(vec![extract_value_for_type(
|
||||||
|
values,
|
||||||
|
field_type.value_type(),
|
||||||
|
field_name,
|
||||||
|
)?]);
|
||||||
|
}
|
||||||
|
|
||||||
values
|
values
|
||||||
.iter()
|
.iter()
|
||||||
.map(|any| {
|
.map(|any| {
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
|
|
||||||
use pyo3::{exceptions, prelude::*};
|
use pyo3::{exceptions, prelude::*};
|
||||||
|
|
||||||
use tantivy::schema;
|
|
||||||
|
|
||||||
use crate::schema::Schema;
|
use crate::schema::Schema;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use tantivy::schema::{DateOptions, INDEXED};
|
use tantivy::schema::{
|
||||||
|
self, BytesOptions, DateOptions, IpAddrOptions, INDEXED,
|
||||||
|
};
|
||||||
|
|
||||||
/// Tantivy has a very strict schema.
|
/// Tantivy has a very strict schema.
|
||||||
/// You need to specify in advance whether a field is indexed or not,
|
/// You need to specify in advance whether a field is indexed or not,
|
||||||
|
@ -357,17 +357,43 @@ impl SchemaBuilder {
|
||||||
|
|
||||||
/// Add a fast bytes field to the schema.
|
/// Add a fast bytes field to the schema.
|
||||||
///
|
///
|
||||||
/// Bytes field are not searchable and are only used
|
|
||||||
/// as fast field, to associate any kind of payload
|
|
||||||
/// to a document.
|
|
||||||
///
|
|
||||||
/// Args:
|
/// Args:
|
||||||
/// name (str): The name of the field.
|
/// name (str): The name of the field.
|
||||||
fn add_bytes_field(&mut self, name: &str) -> PyResult<Self> {
|
/// stored (bool, optional): If true sets the field as stored, the
|
||||||
|
/// content of the field can be later restored from a Searcher.
|
||||||
|
/// Defaults to False.
|
||||||
|
/// indexed (bool, optional): If true sets the field to be indexed.
|
||||||
|
/// fast (str, optional): Set the bytes options as a fast field. A fast
|
||||||
|
/// field is a column-oriented fashion storage for tantivy. It is
|
||||||
|
/// designed for the fast random access of some document fields
|
||||||
|
/// given a document id.
|
||||||
|
#[pyo3(signature = (
|
||||||
|
name,
|
||||||
|
stored = false,
|
||||||
|
indexed = false,
|
||||||
|
fast = false
|
||||||
|
))]
|
||||||
|
fn add_bytes_field(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
stored: bool,
|
||||||
|
indexed: bool,
|
||||||
|
fast: bool,
|
||||||
|
) -> PyResult<Self> {
|
||||||
let builder = &mut self.builder;
|
let builder = &mut self.builder;
|
||||||
|
let mut opts = BytesOptions::default();
|
||||||
|
if stored {
|
||||||
|
opts = opts.set_stored();
|
||||||
|
}
|
||||||
|
if indexed {
|
||||||
|
opts = opts.set_indexed();
|
||||||
|
}
|
||||||
|
if fast {
|
||||||
|
opts = opts.set_fast();
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(builder) = builder.write().unwrap().as_mut() {
|
if let Some(builder) = builder.write().unwrap().as_mut() {
|
||||||
builder.add_bytes_field(name, INDEXED);
|
builder.add_bytes_field(name, opts);
|
||||||
} else {
|
} else {
|
||||||
return Err(exceptions::PyValueError::new_err(
|
return Err(exceptions::PyValueError::new_err(
|
||||||
"Schema builder object isn't valid anymore.",
|
"Schema builder object isn't valid anymore.",
|
||||||
|
@ -376,6 +402,54 @@ impl SchemaBuilder {
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add an IP address field to the schema.
|
||||||
|
///
|
||||||
|
/// Args:
|
||||||
|
/// name (str): The name of the field.
|
||||||
|
/// stored (bool, optional): If true sets the field as stored, the
|
||||||
|
/// content of the field can be later restored from a Searcher.
|
||||||
|
/// Defaults to False.
|
||||||
|
/// indexed (bool, optional): If true sets the field to be indexed.
|
||||||
|
/// fast (str, optional): Set the IP address options as a fast field. A
|
||||||
|
/// fast field is a column-oriented fashion storage for tantivy. It
|
||||||
|
/// is designed for the fast random access of some document fields
|
||||||
|
/// given a document id.
|
||||||
|
#[pyo3(signature = (
|
||||||
|
name,
|
||||||
|
stored = false,
|
||||||
|
indexed = false,
|
||||||
|
fast = false
|
||||||
|
))]
|
||||||
|
fn add_ip_addr_field(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
stored: bool,
|
||||||
|
indexed: bool,
|
||||||
|
fast: bool,
|
||||||
|
) -> PyResult<Self> {
|
||||||
|
let builder = &mut self.builder;
|
||||||
|
let mut opts = IpAddrOptions::default();
|
||||||
|
if stored {
|
||||||
|
opts = opts.set_stored();
|
||||||
|
}
|
||||||
|
if indexed {
|
||||||
|
opts = opts.set_indexed();
|
||||||
|
}
|
||||||
|
if fast {
|
||||||
|
opts = opts.set_fast();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(builder) = builder.write().unwrap().as_mut() {
|
||||||
|
builder.add_ip_addr_field(name, opts);
|
||||||
|
} else {
|
||||||
|
return Err(exceptions::PyValueError::new_err(
|
||||||
|
"Schema builder object isn't valid anymore.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.clone())
|
||||||
|
}
|
||||||
|
|
||||||
/// Finalize the creation of a Schema.
|
/// Finalize the creation of a Schema.
|
||||||
///
|
///
|
||||||
/// Returns a Schema object. After this is called the SchemaBuilder cannot
|
/// Returns a Schema object. After this is called the SchemaBuilder cannot
|
||||||
|
|
|
@ -2,6 +2,7 @@ from io import BytesIO
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
import tantivy
|
import tantivy
|
||||||
import pickle
|
import pickle
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -365,7 +366,9 @@ class TestClass(object):
|
||||||
searched_doc = index.searcher().doc(doc_address)
|
searched_doc = index.searcher().doc(doc_address)
|
||||||
assert searched_doc["title"] == ["Test title"]
|
assert searched_doc["title"] == ["Test title"]
|
||||||
|
|
||||||
result = searcher.search(query, 10, order_by_field="order", order=tantivy.Order.Asc)
|
result = searcher.search(
|
||||||
|
query, 10, order_by_field="order", order=tantivy.Order.Asc
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result.hits) == 3
|
assert len(result.hits) == 3
|
||||||
|
|
||||||
|
@ -443,7 +446,7 @@ class TestClass(object):
|
||||||
|
|
||||||
assert searcher.num_segments < 8
|
assert searcher.num_segments < 8
|
||||||
|
|
||||||
def test_doc_from_dict_schema_validation(self):
|
def test_doc_from_dict_numeric_validation(self):
|
||||||
schema = (
|
schema = (
|
||||||
SchemaBuilder()
|
SchemaBuilder()
|
||||||
.add_unsigned_field("unsigned")
|
.add_unsigned_field("unsigned")
|
||||||
|
@ -504,6 +507,70 @@ class TestClass(object):
|
||||||
schema,
|
schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_doc_from_dict_bytes_validation(self):
|
||||||
|
schema = SchemaBuilder().add_bytes_field("bytes").build()
|
||||||
|
|
||||||
|
good = Document.from_dict({"bytes": b"hello"}, schema)
|
||||||
|
good = Document.from_dict({"bytes": [[1, 2, 3], [4, 5, 6]]}, schema)
|
||||||
|
good = Document.from_dict({"bytes": [1, 2, 3]}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"bytes": [1, 2, 256]}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"bytes": "hello"}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"bytes": [1024, "there"]}, schema)
|
||||||
|
|
||||||
|
def test_doc_from_dict_ip_addr_validation(self):
|
||||||
|
schema = SchemaBuilder().add_ip_addr_field("ip").build()
|
||||||
|
|
||||||
|
good = Document.from_dict({"ip": "127.0.0.1"}, schema)
|
||||||
|
good = Document.from_dict({"ip": "::1"}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"ip": 12309812348}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"ip": "256.100.0.1"}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict(
|
||||||
|
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234"}, schema
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict(
|
||||||
|
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:GHIJ"}, schema
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_doc_from_dict_json_validation(self):
|
||||||
|
# Test implicit JSON
|
||||||
|
good = Document.from_dict({"dict": {"hello": "world"}})
|
||||||
|
|
||||||
|
schema = SchemaBuilder().add_json_field("json").build()
|
||||||
|
|
||||||
|
good = Document.from_dict({"json": {}}, schema)
|
||||||
|
good = Document.from_dict({"json": {"hello": "world"}}, schema)
|
||||||
|
good = Document.from_dict(
|
||||||
|
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]}, schema
|
||||||
|
)
|
||||||
|
|
||||||
|
list_of_jsons = [
|
||||||
|
{"hello": "world"},
|
||||||
|
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]},
|
||||||
|
]
|
||||||
|
good = Document.from_dict({"json": list_of_jsons}, schema)
|
||||||
|
|
||||||
|
good = Document.from_dict({"json": json.dumps(list_of_jsons[1])}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"json": 123}, schema)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bad = Document.from_dict({"json": "hello"}, schema)
|
||||||
|
|
||||||
def test_search_result_eq(self, ram_index, spanish_index):
|
def test_search_result_eq(self, ram_index, spanish_index):
|
||||||
eng_index = ram_index
|
eng_index = ram_index
|
||||||
eng_query = eng_index.parse_query("sea whale", ["title", "body"])
|
eng_query = eng_index.parse_query("sea whale", ["title", "body"])
|
||||||
|
@ -650,10 +717,6 @@ class TestDocument(object):
|
||||||
doc = tantivy.Document(facet=facet)
|
doc = tantivy.Document(facet=facet)
|
||||||
assert doc["facet"][0].to_path() == ["asia/oceania", "fiji"]
|
assert doc["facet"][0].to_path() == ["asia/oceania", "fiji"]
|
||||||
|
|
||||||
def test_document_error(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
tantivy.Document(name={})
|
|
||||||
|
|
||||||
def test_document_eq(self):
|
def test_document_eq(self):
|
||||||
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
|
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
|
||||||
doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]})
|
doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]})
|
||||||
|
@ -848,9 +911,11 @@ class TestSnippets(object):
|
||||||
result = searcher.search(query)
|
result = searcher.search(query)
|
||||||
assert len(result.hits) == 1
|
assert len(result.hits) == 1
|
||||||
|
|
||||||
snippet_generator = SnippetGenerator.create(searcher, query, doc_schema, "title")
|
snippet_generator = SnippetGenerator.create(
|
||||||
|
searcher, query, doc_schema, "title"
|
||||||
|
)
|
||||||
|
|
||||||
for (score, doc_address) in result.hits:
|
for score, doc_address in result.hits:
|
||||||
doc = searcher.doc(doc_address)
|
doc = searcher.doc(doc_address)
|
||||||
snippet = snippet_generator.snippet_from_doc(doc)
|
snippet = snippet_generator.snippet_from_doc(doc)
|
||||||
highlights = snippet.highlighted()
|
highlights = snippet.highlighted()
|
||||||
|
@ -859,4 +924,4 @@ class TestSnippets(object):
|
||||||
assert first.start == 20
|
assert first.start == 20
|
||||||
assert first.end == 23
|
assert first.end == 23
|
||||||
html_snippet = snippet.to_html()
|
html_snippet = snippet.to_html()
|
||||||
assert html_snippet == 'The Old Man and the <b>Sea</b>'
|
assert html_snippet == "The Old Man and the <b>Sea</b>"
|
||||||
|
|
Loading…
Reference in New Issue