Support pickling of some objects (#97)

master
Chris Tam 2023-08-26 08:13:29 -04:00 committed by GitHub
parent f12bac1f97
commit 05dde2d232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 410 additions and 38 deletions

61
Cargo.lock generated
View File

@ -45,7 +45,7 @@ checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -236,7 +236,7 @@ dependencies = [
"proc-macro2",
"quote",
"scratch",
"syn",
"syn 1.0.107",
]
[[package]]
@ -253,7 +253,7 @@ checksum = "ebf883b7aacd7b2aeb2a7b338648ee19f57c140d4ee8e52c68979c6b2f7f2263"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -387,7 +387,7 @@ checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -800,9 +800,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.51"
version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
dependencies = [
"unicode-ident",
]
@ -854,7 +854,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -865,14 +865,24 @@ checksum = "e0b78ccbb160db1556cdb6fd96c50334c5d4ec44dc5e0a968d0a1208fa0efa8b"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
name = "pythonize"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e35b716d430ace57e2d1b4afb51c9e5b7c46d2bce72926e07f9be6a98ced03e"
dependencies = [
"pyo3",
"serde",
]
[[package]]
name = "quote"
version = "1.0.23"
version = "1.0.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
checksum = "5fe8a65d69dd0808184ebb5f836ab526bb259db23c657efa38711b1072ee47f0"
dependencies = [
"proc-macro2",
]
@ -1032,22 +1042,22 @@ checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2"
[[package]]
name = "serde"
version = "1.0.152"
version = "1.0.181"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb"
checksum = "6d3e73c93c3240c0bda063c239298e633114c69a888c3e37ca8bb33f343e9890"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.152"
version = "1.0.181"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e"
checksum = "be02f6cb0cd3a5ec20bbcfbcbd749f57daddb1a0882dc2e46a6c236c90b977ed"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.28",
]
[[package]]
@ -1111,6 +1121,17 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tantivy"
version = "0.20.1"
@ -1120,6 +1141,8 @@ dependencies = [
"itertools",
"pyo3",
"pyo3-build-config",
"pythonize",
"serde",
"serde_json",
"tantivy 0.20.2",
]
@ -1313,7 +1336,7 @@ checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -1384,7 +1407,7 @@ checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -1505,7 +1528,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
"wasm-bindgen-shared",
]
@ -1527,7 +1550,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]

View File

@ -18,6 +18,8 @@ chrono = "0.4.23"
tantivy = "0.20.1"
itertools = "0.10.5"
futures = "0.3.26"
pythonize = "0.19.0"
serde = "1.0"
serde_json = "1.0.91"
[dependencies.pyo3]

View File

@ -13,15 +13,18 @@ use pyo3::{
use chrono::{offset::TimeZone, NaiveDateTime, Utc};
use tantivy as tv;
use tantivy::{self as tv, schema::Value};
use crate::{facet::Facet, schema::Schema, to_pyerr};
use serde::{
ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer,
};
use serde_json::Value as JsonValue;
use std::{
collections::{BTreeMap, HashMap},
fmt,
net::Ipv6Addr,
};
use tantivy::schema::Value;
pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
if let Ok(s) = any.extract::<String>() {
@ -222,6 +225,149 @@ fn value_to_string(value: &Value) -> String {
}
}
/// Serializes a [`tv::DateTime`] object.
///
/// Tantivy represents the value as a single `i64` nanosecond timestamp, so
/// that integer is what goes over the wire; `deserialize_datetime` is the
/// matching inverse.
fn serialize_datetime<S: Serializer>(
    dt: &tv::DateTime,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    let nanos: i64 = dt.into_timestamp_nanos();
    nanos.serialize(serializer)
}
/// Deserializes a [`tv::DateTime`] object.
///
/// Reads the single `i64` nanosecond timestamp written by
/// `serialize_datetime` and rebuilds the tantivy value from it.
fn deserialize_datetime<'de, D>(
    deserializer: D,
) -> Result<tv::DateTime, D::Error>
where
    D: Deserializer<'de>,
{
    let nanos = i64::deserialize(deserializer)?;
    Ok(tv::DateTime::from_timestamp_nanos(nanos))
}
/// An equivalent type to [`tantivy::schema::Value`], but unlike the tantivy crate's serialization
/// implementation, it uses tagging in its serialization and deserialization to differentiate
/// between different integer types.
///
/// [`BorrowedSerdeValue`] is often used for the serialization path, as owning the data is not
/// necessary for serialization.
#[derive(Deserialize, Serialize)]
enum SerdeValue {
    /// The str type is used for any text information.
    Str(String),
    /// Pre-tokenized str type.
    PreTokStr(tv::tokenizer::PreTokenizedString),
    /// Unsigned 64-bits Integer `u64`
    U64(u64),
    /// Signed 64-bits Integer `i64`
    I64(i64),
    /// 64-bits Float `f64`
    F64(f64),
    /// Bool value
    Bool(bool),
    #[serde(
        deserialize_with = "deserialize_datetime",
        serialize_with = "serialize_datetime"
    )]
    /// Date/time with microseconds precision; on the wire it is a single
    /// `i64` nanosecond timestamp (see `serialize_datetime`).
    Date(tv::DateTime),
    /// Facet
    Facet(tv::schema::Facet),
    /// Arbitrarily sized byte array
    Bytes(Vec<u8>),
    /// Json object value.
    JsonObject(serde_json::Map<String, serde_json::Value>),
    /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
    IpAddr(Ipv6Addr),
}
impl From<SerdeValue> for Value {
    /// Maps each tagged serde variant back onto the corresponding tantivy
    /// `Value` variant; a pure 1:1 re-wrap with no data transformation.
    fn from(value: SerdeValue) -> Self {
        match value {
            SerdeValue::Str(s) => Self::Str(s),
            SerdeValue::PreTokStr(pts) => Self::PreTokStr(pts),
            SerdeValue::U64(n) => Self::U64(n),
            SerdeValue::I64(n) => Self::I64(n),
            SerdeValue::F64(x) => Self::F64(x),
            SerdeValue::Bool(b) => Self::Bool(b),
            SerdeValue::Date(dt) => Self::Date(dt),
            SerdeValue::Facet(facet) => Self::Facet(facet),
            SerdeValue::Bytes(bytes) => Self::Bytes(bytes),
            SerdeValue::JsonObject(obj) => Self::JsonObject(obj),
            SerdeValue::IpAddr(ip) => Self::IpAddr(ip),
        }
    }
}
impl From<Value> for SerdeValue {
    /// Converts an owned tantivy `Value` into its tagged serde twin; every
    /// variant's payload is moved across unchanged.
    fn from(value: Value) -> Self {
        match value {
            Value::Str(s) => Self::Str(s),
            Value::PreTokStr(pts) => Self::PreTokStr(pts),
            Value::U64(n) => Self::U64(n),
            Value::I64(n) => Self::I64(n),
            Value::F64(x) => Self::F64(x),
            Value::Bool(b) => Self::Bool(b),
            Value::Date(dt) => Self::Date(dt),
            Value::Facet(facet) => Self::Facet(facet),
            Value::Bytes(bytes) => Self::Bytes(bytes),
            Value::JsonObject(obj) => Self::JsonObject(obj),
            Value::IpAddr(ip) => Self::IpAddr(ip),
        }
    }
}
/// A non-owning version of [`SerdeValue`]. This is used in serialization to avoid unnecessary
/// cloning.
#[derive(Serialize)]
enum BorrowedSerdeValue<'a> {
    /// The str type is used for any text information.
    Str(&'a str),
    /// Pre-tokenized str type.
    PreTokStr(&'a tv::tokenizer::PreTokenizedString),
    /// Unsigned 64-bits Integer `u64`
    U64(&'a u64),
    /// Signed 64-bits Integer `i64`
    I64(&'a i64),
    /// 64-bits Float `f64`
    F64(&'a f64),
    /// Bool value
    Bool(&'a bool),
    #[serde(serialize_with = "serialize_datetime")]
    /// Date/time with microseconds precision; written to the wire as a
    /// single `i64` nanosecond timestamp (see `serialize_datetime`).
    Date(&'a tv::DateTime),
    /// Facet
    Facet(&'a tv::schema::Facet),
    /// Arbitrarily sized byte array
    Bytes(&'a [u8]),
    /// Json object value.
    JsonObject(&'a serde_json::Map<String, serde_json::Value>),
    /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
    IpAddr(&'a Ipv6Addr),
}
impl<'a> From<&'a Value> for BorrowedSerdeValue<'a> {
    /// Borrows each payload of a tantivy `Value` into the matching
    /// `BorrowedSerdeValue` variant so the serialization path never clones.
    fn from(value: &'a Value) -> Self {
        match value {
            Value::Str(s) => Self::Str(s),
            Value::PreTokStr(pts) => Self::PreTokStr(pts),
            Value::U64(n) => Self::U64(n),
            Value::I64(n) => Self::I64(n),
            Value::F64(x) => Self::F64(x),
            Value::Bool(b) => Self::Bool(b),
            Value::Date(dt) => Self::Date(dt),
            Value::Facet(facet) => Self::Facet(facet),
            Value::Bytes(bytes) => Self::Bytes(bytes),
            Value::JsonObject(obj) => Self::JsonObject(obj),
            Value::IpAddr(ip) => Self::IpAddr(ip),
        }
    }
}
/// Tantivy's Document is the object that can be indexed and then searched for.
///
/// Documents are fundamentally a collection of unordered tuples
@ -264,10 +410,10 @@ fn value_to_string(value: &Value) -> String {
/// {"unsigned": 1000, "signed": -5, "float": 0.4},
/// schema,
/// )
#[pyclass]
#[pyclass(module = "tantivy")]
#[derive(Clone, Default, PartialEq)]
pub(crate) struct Document {
pub(crate) field_values: BTreeMap<String, Vec<tv::schema::Value>>,
pub(crate) field_values: BTreeMap<String, Vec<Value>>,
}
impl fmt::Debug for Document {
@ -290,6 +436,42 @@ impl fmt::Debug for Document {
}
}
impl Serialize for Document {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map =
serializer.serialize_map(Some(self.field_values.len()))?;
for (k, v) in &self.field_values {
let ser_v: Vec<_> =
v.iter().map(BorrowedSerdeValue::from).collect();
map.serialize_entry(&k, &ser_v)?;
}
map.end()
}
}
impl<'de> Deserialize<'de> for Document {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
BTreeMap::<String, Vec<SerdeValue>>::deserialize(deserializer).map(
|field_map| Document {
field_values: field_map
.into_iter()
.map(|(k, v)| {
let v: Vec<_> =
v.into_iter().map(Value::from).collect();
(k, v)
})
.collect(),
},
)
}
}
#[pymethods]
impl Document {
/// Creates a new document with optional fields from `**kwargs`.
@ -529,6 +711,26 @@ impl Document {
_ => py.NotImplemented(),
}
}
#[staticmethod]
/// Rebuilds a `Document` from the pythonized value produced by
/// `__reduce__`; this is the callable that pickle invokes when unpickling.
fn _internal_from_pythonized(serialized: &PyAny) -> PyResult<Self> {
    pythonize::depythonize(serialized).map_err(to_pyerr)
}
/// Implements Python's pickle protocol.
///
/// Returns the standard `(callable, args)` reduce tuple: the bound
/// `_internal_from_pythonized` static method, plus a 1-tuple holding the
/// pythonized (serde-serialized) form of this document.
fn __reduce__<'a>(
    slf: PyRef<'a, Self>,
    py: Python<'a>,
) -> PyResult<&'a PyTuple> {
    // Serialize `self` into a tree of Python objects via serde + pythonize.
    let serialized = pythonize::pythonize(py, &*slf).map_err(to_pyerr)?;
    Ok(PyTuple::new(
        py,
        [
            // The callable pickle will invoke to rebuild the object.
            slf.into_py(py).getattr(py, "_internal_from_pythonized")?,
            // Its argument tuple.
            PyTuple::new(py, [serialized]).to_object(py),
        ],
    ))
}
}
impl Document {

View File

@ -1,4 +1,10 @@
use pyo3::{basic::CompareOp, prelude::*, types::PyType};
use crate::to_pyerr;
use pyo3::{
basic::CompareOp,
prelude::*,
types::{PyTuple, PyType},
};
use serde::{Deserialize, Serialize};
use tantivy::schema;
/// A Facet represent a point in a given hierarchy.
@ -10,14 +16,22 @@ use tantivy::schema;
implicitly imply that a document belonging to a facet also belongs to the
/// ancestor of its facet. In the example above, /electronics/tv_and_video/
/// and /electronics.
#[pyclass(frozen)]
#[derive(Clone, PartialEq)]
#[pyclass(frozen, module = "tantivy")]
#[derive(Clone, Deserialize, PartialEq, Serialize)]
pub(crate) struct Facet {
pub(crate) inner: schema::Facet,
}
#[pymethods]
impl Facet {
/// Creates a `Facet` from its binary representation.
#[staticmethod]
fn from_encoded(encoded_bytes: Vec<u8>) -> PyResult<Self> {
let inner =
schema::Facet::from_encoded(encoded_bytes).map_err(to_pyerr)?;
Ok(Self { inner })
}
/// Create a new instance of the "root facet" Equivalent to /.
#[classmethod]
fn root(_cls: &PyType) -> Facet {
@ -80,4 +94,18 @@ impl Facet {
_ => py.NotImplemented(),
}
}
/// Implements Python's pickle protocol.
///
/// Reduces to the `from_encoded` static method applied to this facet's
/// encoded byte representation, so unpickling goes through `from_encoded`.
fn __reduce__<'a>(
    slf: PyRef<'a, Self>,
    py: Python<'a>,
) -> PyResult<&'a PyTuple> {
    // The encoded string form is the facet's canonical binary representation.
    let encoded_bytes = slf.inner.encoded_str().as_bytes().to_vec();
    Ok(PyTuple::new(
        py,
        [
            // Callable pickle invokes on unpickling.
            slf.into_py(py).getattr(py, "from_encoded")?,
            // Its argument tuple.
            PyTuple::new(py, [encoded_bytes]).to_object(py),
        ],
    ))
}
}

View File

@ -14,7 +14,7 @@ use facet::Facet;
use index::Index;
use schema::Schema;
use schemabuilder::SchemaBuilder;
use searcher::{DocAddress, Searcher};
use searcher::{DocAddress, SearchResult, Searcher};
/// Python bindings for the search engine library Tantivy.
///
@ -71,6 +71,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Schema>()?;
m.add_class::<SchemaBuilder>()?;
m.add_class::<Searcher>()?;
m.add_class::<SearchResult>()?;
m.add_class::<Document>()?;
m.add_class::<Index>()?;
m.add_class::<DocAddress>()?;

View File

@ -1,12 +1,14 @@
use pyo3::{basic::CompareOp, prelude::*};
use crate::to_pyerr;
use pyo3::{basic::CompareOp, prelude::*, types::PyTuple};
use serde::{Deserialize, Serialize};
use tantivy as tv;
/// Tantivy schema.
///
/// The schema is very strict. To build the schema the `SchemaBuilder` class is
/// provided.
#[pyclass(frozen)]
#[derive(PartialEq)]
#[pyclass(frozen, module = "tantivy")]
#[derive(Deserialize, PartialEq, Serialize)]
pub(crate) struct Schema {
pub(crate) inner: tv::schema::Schema,
}
@ -25,4 +27,24 @@ impl Schema {
_ => py.NotImplemented(),
}
}
#[staticmethod]
/// Rebuilds a `Schema` from the pythonized value produced by
/// `__reduce__`; this is the callable that pickle invokes when unpickling.
fn _internal_from_pythonized(serialized: &PyAny) -> PyResult<Self> {
    pythonize::depythonize(serialized).map_err(to_pyerr)
}
/// Implements Python's pickle protocol.
///
/// Returns the `(callable, args)` reduce tuple: the bound
/// `_internal_from_pythonized` static method plus a 1-tuple holding the
/// pythonized (serde-serialized) form of this schema.
fn __reduce__<'a>(
    slf: PyRef<'a, Self>,
    py: Python<'a>,
) -> PyResult<&'a PyTuple> {
    // Serialize `self` into a tree of Python objects via serde + pythonize.
    let serialized = pythonize::pythonize(py, &*slf).map_err(to_pyerr)?;
    Ok(PyTuple::new(
        py,
        [
            // The callable pickle will invoke to rebuild the object.
            slf.into_py(py).getattr(py, "_internal_from_pythonized")?,
            // Its argument tuple.
            PyTuple::new(py, [serialized]).to_object(py),
        ],
    ))
}
}

View File

@ -2,6 +2,7 @@
use crate::{document::Document, query::Query, to_pyerr};
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};
@ -13,9 +14,11 @@ pub(crate) struct Searcher {
pub(crate) inner: tv::Searcher,
}
#[derive(Clone, PartialEq)]
#[derive(Clone, Deserialize, FromPyObject, PartialEq, Serialize)]
enum Fruit {
#[pyo3(transparent)]
Score(f32),
#[pyo3(transparent)]
Order(u64),
}
@ -37,8 +40,8 @@ impl ToPyObject for Fruit {
}
}
#[pyclass(frozen)]
#[derive(Clone, PartialEq)]
#[pyclass(frozen, module = "tantivy")]
#[derive(Clone, Default, Deserialize, PartialEq, Serialize)]
/// Object holding the results of a successful search.
pub(crate) struct SearchResult {
hits: Vec<(Fruit, DocAddress)>,
@ -50,6 +53,19 @@ pub(crate) struct SearchResult {
#[pymethods]
impl SearchResult {
#[new]
/// Python-visible constructor; also invoked during unpickling with the
/// arguments produced by `__getnewargs__`.
///
/// `hits` arrives as `(PyObject, DocAddress)` pairs; each score object is
/// extracted into the internal `Fruit` representation (an `f32` score or a
/// `u64` order), propagating the extraction error if neither matches.
fn new(
    py: Python,
    hits: Vec<(PyObject, DocAddress)>,
    count: Option<usize>,
) -> PyResult<Self> {
    let hits = hits
        .iter()
        .map(|(f, d)| Ok((f.extract(py)?, d.clone())))
        .collect::<PyResult<Vec<_>>>()?;
    Ok(Self { hits, count })
}
fn __repr__(&self) -> PyResult<String> {
if let Some(count) = self.count {
Ok(format!(
@ -74,6 +90,13 @@ impl SearchResult {
}
}
/// Supports pickling: returns the constructor arguments needed to rebuild
/// this result via `new` — the hits converted back to Python objects, plus
/// the optional total count.
fn __getnewargs__(
    &self,
    py: Python,
) -> PyResult<(Vec<(PyObject, DocAddress)>, Option<usize>)> {
    Ok((self.hits(py)?, self.count))
}
#[getter]
/// The list of tuples that contains the scores and DocAddress of the
/// search results.
@ -214,8 +237,8 @@ impl Searcher {
/// It consists in an id identifying its segment, and its segment-local DocId.
/// The id used for the segment is actually an ordinal in the list of segment
/// hold by a Searcher.
#[pyclass(frozen)]
#[derive(Clone, Debug, PartialEq)]
#[pyclass(frozen, module = "tantivy")]
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub(crate) struct DocAddress {
pub(crate) segment_ord: tv::SegmentOrdinal,
pub(crate) doc: tv::DocId,
@ -223,6 +246,11 @@ pub(crate) struct DocAddress {
#[pymethods]
impl DocAddress {
#[new]
fn new(segment_ord: tv::SegmentOrdinal, doc: tv::DocId) -> Self {
DocAddress { segment_ord, doc }
}
/// The segment ordinal is an id identifying the segment hosting the
/// document. It is only meaningful, in the context of a searcher.
#[getter]
@ -248,6 +276,10 @@ impl DocAddress {
_ => py.NotImplemented(),
}
}
/// Supports pickling: the `(segment_ord, doc)` pair is fed back to `new`
/// when unpickling.
fn __getnewargs__(&self) -> PyResult<(tv::SegmentOrdinal, tv::DocId)> {
    Ok((self.segment_ord, self.doc))
}
}
impl From<&tv::DocAddress> for DocAddress {

View File

@ -1,6 +1,9 @@
from io import BytesIO
import copy
import datetime
import tantivy
import pickle
import pytest
from tantivy import Document, Index, SchemaBuilder
@ -476,6 +479,15 @@ class TestClass(object):
assert eng_result1 != esp_result
assert eng_result2 != esp_result
def test_search_result_pickle(self, ram_index):
    """A SearchResult must compare equal after a pickle round trip."""
    query = ram_index.parse_query("sea whale", ["title", "body"])
    result = ram_index.searcher().search(query, 10)
    restored = pickle.loads(pickle.dumps(result))
    assert result == restored
class TestUpdateClass(object):
def test_delete_update(self, ram_index):
@ -544,7 +556,10 @@ class TestFromDiskClass(object):
class TestSearcher(object):
def test_searcher_repr(self, ram_index, ram_index_numeric_fields):
assert repr(ram_index.searcher()) == "Searcher(num_docs=3, num_segments=1)"
assert repr(ram_index_numeric_fields.searcher()) == "Searcher(num_docs=2, num_segments=1)"
assert (
repr(ram_index_numeric_fields.searcher())
== "Searcher(num_docs=2, num_segments=1)"
)
class TestDocument(object):
@ -557,8 +572,6 @@ class TestDocument(object):
assert doc.to_dict() == {"name": ["Bill"], "reference": [1, 2]}
def test_document_with_date(self):
import datetime
date = datetime.datetime(2019, 8, 12, 13, 0, 0)
doc = tantivy.Document(name="Bill", date=date)
assert doc["date"][0] == date
@ -607,6 +620,23 @@ class TestDocument(object):
assert doc1 == doc3
assert doc2 == doc3
def test_document_pickle(self):
    """A Document holding every supported value type must pickle losslessly."""
    doc = Document()
    doc.add_unsigned("unsigned", 1)
    doc.add_integer("integer", 5)
    doc.add_float("float", 1.0)
    doc.add_date("birth", datetime.datetime(2019, 8, 12, 13, 0, 5))
    doc.add_text("title", "hello world!")
    doc.add_json("json", '{"a": 1, "b": 2}')
    doc.add_bytes("bytes", b"abc")
    doc.add_facet("facet", tantivy.Facet.from_string("/europe/france"))

    restored = pickle.loads(pickle.dumps(doc))
    assert doc == restored
class TestJsonField:
def test_query_from_json_field(self):
@ -722,3 +752,35 @@ def test_facet_eq():
assert facet1 == facet2
assert facet1 != facet3
assert facet2 != facet3
def test_schema_pickle():
    """A Schema built with a variety of field types must survive pickling."""
    builder = SchemaBuilder()
    builder.add_integer_field("id", stored=True, indexed=True)
    builder.add_unsigned_field("unsigned")
    builder.add_float_field("rating", stored=True, indexed=True)
    builder.add_text_field("body", stored=True)
    builder.add_date_field("date")
    builder.add_json_field("json")
    builder.add_bytes_field("bytes")
    schema = builder.build()

    restored = pickle.loads(pickle.dumps(schema))
    assert schema == restored
def test_facet_pickle():
    """A Facet must compare equal after a pickle round trip."""
    facet = tantivy.Facet.from_string("/europe/france")
    restored = pickle.loads(pickle.dumps(facet))
    assert facet == restored
def test_doc_address_pickle():
    """A DocAddress must compare equal after a pickle round trip."""
    address = tantivy.DocAddress(42, 123)
    restored = pickle.loads(pickle.dumps(address))
    assert address == restored