diff --git a/Cargo.lock b/Cargo.lock
index 7374aef..28f299a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -45,7 +45,7 @@ checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 1.0.107",
 ]

 [[package]]
@@ -236,7 +236,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "scratch",
- "syn",
+ "syn 1.0.107",
 ]

 [[package]]
@@ -253,7 +253,7 @@ checksum = "ebf883b7aacd7b2aeb2a7b338648ee19f57c140d4ee8e52c68979c6b2f7f2263"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 1.0.107",
 ]

 [[package]]
@@ -387,7 +387,7 @@ checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 1.0.107",
 ]

 [[package]]
@@ -800,9 +800,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"

 [[package]]
 name = "proc-macro2"
-version = "1.0.51"
+version = "1.0.66"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
+checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9"
 dependencies = [
  "unicode-ident",
 ]
@@ -854,7 +854,7 @@ dependencies = [
  "proc-macro2",
  "pyo3-macros-backend",
  "quote",
- "syn",
+ "syn 1.0.107",
 ]

 [[package]]
@@ -865,14 +865,24 @@ checksum = "e0b78ccbb160db1556cdb6fd96c50334c5d4ec44dc5e0a968d0a1208fa0efa8b"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 1.0.107",
+]
+
+[[package]]
+name = "pythonize"
+version = "0.19.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e35b716d430ace57e2d1b4afb51c9e5b7c46d2bce72926e07f9be6a98ced03e"
+dependencies = [
+ "pyo3",
+ "serde",
 ]

 [[package]]
 name = "quote"
-version = "1.0.23"
+version = "1.0.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
+checksum = "5fe8a65d69dd0808184ebb5f836ab526bb259db23c657efa38711b1072ee47f0"
 dependencies = [
  "proc-macro2",
 ]
@@ -1032,22 +1042,22 @@ checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2"

 [[package]]
 name = "serde"
-version = "1.0.152"
+version = "1.0.181"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb"
+checksum = "6d3e73c93c3240c0bda063c239298e633114c69a888c3e37ca8bb33f343e9890"
 dependencies = [
  "serde_derive",
 ]

 [[package]]
 name = "serde_derive"
-version = "1.0.152"
+version = "1.0.181"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e"
+checksum = "be02f6cb0cd3a5ec20bbcfbcbd749f57daddb1a0882dc2e46a6c236c90b977ed"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 2.0.28",
 ]

 [[package]]
@@ -1111,6 +1121,17 @@ dependencies = [
  "unicode-ident",
 ]

+[[package]]
+name = "syn"
+version = "2.0.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
 [[package]]
 name = "tantivy"
 version = "0.20.1"
@@ -1120,6 +1141,8 @@ dependencies = [
  "itertools",
  "pyo3",
  "pyo3-build-config",
+ "pythonize",
+ "serde",
  "serde_json",
  "tantivy 0.20.2",
 ]
@@ -1313,7 +1336,7 @@ checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 1.0.107",
 ]

 [[package]]
@@ checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -1505,7 +1528,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.107", "wasm-bindgen-shared", ] @@ -1527,7 +1550,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index a46cc18..1c190cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ chrono = "0.4.23" tantivy = "0.20.1" itertools = "0.10.5" futures = "0.3.26" +pythonize = "0.19.0" +serde = "1.0" serde_json = "1.0.91" [dependencies.pyo3] diff --git a/src/document.rs b/src/document.rs index b930b64..7e15946 100644 --- a/src/document.rs +++ b/src/document.rs @@ -13,15 +13,18 @@ use pyo3::{ use chrono::{offset::TimeZone, NaiveDateTime, Utc}; -use tantivy as tv; +use tantivy::{self as tv, schema::Value}; use crate::{facet::Facet, schema::Schema, to_pyerr}; +use serde::{ + ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer, +}; use serde_json::Value as JsonValue; use std::{ collections::{BTreeMap, HashMap}, fmt, + net::Ipv6Addr, }; -use tantivy::schema::Value; pub(crate) fn extract_value(any: &PyAny) -> PyResult { if let Ok(s) = any.extract::() { @@ -222,6 +225,149 @@ fn value_to_string(value: &Value) -> String { } } +/// Serializes a [`tv::DateTime`] object. +/// +/// Since tantivy stores it as a single `i64` nanosecond timestamp, it is serialized and +/// deserialized as one. +fn serialize_datetime( + dt: &tv::DateTime, + serializer: S, +) -> Result { + dt.into_timestamp_nanos().serialize(serializer) +} + +/// Deserializes a [`tv::DateTime`] object. +/// +/// Since tantivy stores it as a single `i64` nanosecond timestamp, it is serialized and +/// deserialized as one. +fn deserialize_datetime<'de, D>( + deserializer: D, +) -> Result +where + D: Deserializer<'de>, +{ + i64::deserialize(deserializer).map(tv::DateTime::from_timestamp_nanos) +} + +/// An equivalent type to [`tantivy::schema::Value`], but unlike the tantivy crate's serialization +/// implementation, it uses tagging in its serialization and deserialization to differentiate +/// between different integer types. +/// +/// [`BorrowedSerdeValue`] is often used for the serialization path, as owning the data is not +/// necessary for serialization. +#[derive(Deserialize, Serialize)] +enum SerdeValue { + /// The str type is used for any text information. + Str(String), + /// Pre-tokenized str type, + PreTokStr(tv::tokenizer::PreTokenizedString), + /// Unsigned 64-bits Integer `u64` + U64(u64), + /// Signed 64-bits Integer `i64` + I64(i64), + /// 64-bits Float `f64` + F64(f64), + /// Bool value + Bool(bool), + #[serde( + deserialize_with = "deserialize_datetime", + serialize_with = "serialize_datetime" + )] + /// Date/time with microseconds precision + Date(tv::DateTime), + /// Facet + Facet(tv::schema::Facet), + /// Arbitrarily sized byte array + Bytes(Vec), + /// Json object value. + JsonObject(serde_json::Map), + /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`. 
+    IpAddr(Ipv6Addr),
+}
+
+impl From<SerdeValue> for Value {
+    fn from(value: SerdeValue) -> Self {
+        match value {
+            SerdeValue::Str(v) => Self::Str(v),
+            SerdeValue::PreTokStr(v) => Self::PreTokStr(v),
+            SerdeValue::U64(v) => Self::U64(v),
+            SerdeValue::I64(v) => Self::I64(v),
+            SerdeValue::F64(v) => Self::F64(v),
+            SerdeValue::Date(v) => Self::Date(v),
+            SerdeValue::Facet(v) => Self::Facet(v),
+            SerdeValue::Bytes(v) => Self::Bytes(v),
+            SerdeValue::JsonObject(v) => Self::JsonObject(v),
+            SerdeValue::Bool(v) => Self::Bool(v),
+            SerdeValue::IpAddr(v) => Self::IpAddr(v),
+        }
+    }
+}
+
+impl From<Value> for SerdeValue {
+    fn from(value: Value) -> Self {
+        match value {
+            Value::Str(v) => Self::Str(v),
+            Value::PreTokStr(v) => Self::PreTokStr(v),
+            Value::U64(v) => Self::U64(v),
+            Value::I64(v) => Self::I64(v),
+            Value::F64(v) => Self::F64(v),
+            Value::Date(v) => Self::Date(v),
+            Value::Facet(v) => Self::Facet(v),
+            Value::Bytes(v) => Self::Bytes(v),
+            Value::JsonObject(v) => Self::JsonObject(v),
+            Value::Bool(v) => Self::Bool(v),
+            Value::IpAddr(v) => Self::IpAddr(v),
+        }
+    }
+}
+
+/// A non-owning version of [`SerdeValue`]. This is used in serialization to avoid unnecessary
+/// cloning.
+#[derive(Serialize)]
+enum BorrowedSerdeValue<'a> {
+    /// The str type is used for any text information.
+    Str(&'a str),
+    /// Pre-tokenized str type,
+    PreTokStr(&'a tv::tokenizer::PreTokenizedString),
+    /// Unsigned 64-bits Integer `u64`
+    U64(&'a u64),
+    /// Signed 64-bits Integer `i64`
+    I64(&'a i64),
+    /// 64-bits Float `f64`
+    F64(&'a f64),
+    /// Bool value
+    Bool(&'a bool),
+    #[serde(serialize_with = "serialize_datetime")]
+    /// Date/time with microseconds precision
+    Date(&'a tv::DateTime),
+    /// Facet
+    Facet(&'a tv::schema::Facet),
+    /// Arbitrarily sized byte array
+    Bytes(&'a [u8]),
+    /// Json object value.
+    JsonObject(&'a serde_json::Map<String, serde_json::Value>),
+    /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
+    IpAddr(&'a Ipv6Addr),
+}
+
+impl<'a> From<&'a Value> for BorrowedSerdeValue<'a> {
+    fn from(value: &'a Value) -> Self {
+        match value {
+            Value::Str(v) => Self::Str(v),
+            Value::PreTokStr(v) => Self::PreTokStr(v),
+            Value::U64(v) => Self::U64(v),
+            Value::I64(v) => Self::I64(v),
+            Value::F64(v) => Self::F64(v),
+            Value::Date(v) => Self::Date(v),
+            Value::Facet(v) => Self::Facet(v),
+            Value::Bytes(v) => Self::Bytes(v),
+            Value::JsonObject(v) => Self::JsonObject(v),
+            Value::Bool(v) => Self::Bool(v),
+            Value::IpAddr(v) => Self::IpAddr(v),
+        }
+    }
+}
+
 /// Tantivy's Document is the object that can be indexed and then searched for.
 ///
 /// Documents are fundamentally a collection of unordered tuples
@@ -264,10 +410,10 @@ fn value_to_string(value: &Value) -> String {
 ///         {"unsigned": 1000, "signed": -5, "float": 0.4},
 ///         schema,
 ///     )
-#[pyclass]
+#[pyclass(module = "tantivy")]
 #[derive(Clone, Default, PartialEq)]
 pub(crate) struct Document {
-    pub(crate) field_values: BTreeMap<String, Vec<tv::schema::Value>>,
+    pub(crate) field_values: BTreeMap<String, Vec<Value>>,
 }

 impl fmt::Debug for Document {
@@ -290,6 +436,42 @@ impl fmt::Debug for Document {
     }
 }

+impl Serialize for Document {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut map =
+            serializer.serialize_map(Some(self.field_values.len()))?;
+        for (k, v) in &self.field_values {
+            let ser_v: Vec<_> =
+                v.iter().map(BorrowedSerdeValue::from).collect();
+            map.serialize_entry(&k, &ser_v)?;
+        }
+        map.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for Document {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        BTreeMap::<String, Vec<SerdeValue>>::deserialize(deserializer).map(
+            |field_map| Document {
+                field_values: field_map
+                    .into_iter()
+                    .map(|(k, v)| {
+                        let v: Vec<_> =
+                            v.into_iter().map(Value::from).collect();
+                        (k, v)
+                    })
+                    .collect(),
+            },
+        )
+    }
+}
+
 #[pymethods]
 impl Document {
     /// Creates a new document with optional fields from `**kwargs`.
@@ -529,6 +711,26 @@ impl Document {
             _ => py.NotImplemented(),
         }
     }
+
+    #[staticmethod]
+    fn _internal_from_pythonized(serialized: &PyAny) -> PyResult<Self> {
+        pythonize::depythonize(serialized).map_err(to_pyerr)
+    }
+
+    fn __reduce__<'a>(
+        slf: PyRef<'a, Self>,
+        py: Python<'a>,
+    ) -> PyResult<&'a PyTuple> {
+        let serialized = pythonize::pythonize(py, &*slf).map_err(to_pyerr)?;
+
+        Ok(PyTuple::new(
+            py,
+            [
+                slf.into_py(py).getattr(py, "_internal_from_pythonized")?,
+                PyTuple::new(py, [serialized]).to_object(py),
+            ],
+        ))
+    }
 }

 impl Document {
diff --git a/src/facet.rs b/src/facet.rs
index a624e24..2983fe2 100644
--- a/src/facet.rs
+++ b/src/facet.rs
@@ -1,4 +1,10 @@
-use pyo3::{basic::CompareOp, prelude::*, types::PyType};
+use crate::to_pyerr;
+use pyo3::{
+    basic::CompareOp,
+    prelude::*,
+    types::{PyTuple, PyType},
+};
+use serde::{Deserialize, Serialize};
 use tantivy::schema;

 /// A Facet represent a point in a given hierarchy.
@@ -10,14 +16,22 @@ use tantivy::schema;
 /// implicitely imply that a document belonging to a facet also belongs to the
 /// ancestor of its facet. In the example above, /electronics/tv_and_video/
 /// and /electronics.
-#[pyclass(frozen)]
-#[derive(Clone, PartialEq)]
+#[pyclass(frozen, module = "tantivy")]
+#[derive(Clone, Deserialize, PartialEq, Serialize)]
 pub(crate) struct Facet {
     pub(crate) inner: schema::Facet,
 }

 #[pymethods]
 impl Facet {
+    /// Creates a `Facet` from its binary representation.
+    #[staticmethod]
+    fn from_encoded(encoded_bytes: Vec<u8>) -> PyResult<Self> {
+        let inner =
+            schema::Facet::from_encoded(encoded_bytes).map_err(to_pyerr)?;
+        Ok(Self { inner })
+    }
+
     /// Create a new instance of the "root facet" Equivalent to /.
     #[classmethod]
     fn root(_cls: &PyType) -> Facet {
@@ -80,4 +94,18 @@ impl Facet {
             _ => py.NotImplemented(),
         }
     }
+
+    fn __reduce__<'a>(
+        slf: PyRef<'a, Self>,
+        py: Python<'a>,
+    ) -> PyResult<&'a PyTuple> {
+        let encoded_bytes = slf.inner.encoded_str().as_bytes().to_vec();
+        Ok(PyTuple::new(
+            py,
+            [
+                slf.into_py(py).getattr(py, "from_encoded")?,
+                PyTuple::new(py, [encoded_bytes]).to_object(py),
+            ],
+        ))
+    }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 7fe6c2a..245cfee 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,7 +14,7 @@ use facet::Facet;
 use index::Index;
 use schema::Schema;
 use schemabuilder::SchemaBuilder;
-use searcher::{DocAddress, Searcher};
+use searcher::{DocAddress, SearchResult, Searcher};

 /// Python bindings for the search engine library Tantivy.
 ///
@@ -71,6 +71,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<Document>()?;
     m.add_class::<Index>()?;
     m.add_class::<DocAddress>()?;
+    m.add_class::<SearchResult>()?;
     m.add_class::<Facet>()?;
     m.add_class::<Query>()?;
     m.add_class::<Snippet>()?;
diff --git a/src/schema.rs b/src/schema.rs
index 61cf273..ba0c740 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -1,12 +1,14 @@
-use pyo3::{basic::CompareOp, prelude::*};
+use crate::to_pyerr;
+use pyo3::{basic::CompareOp, prelude::*, types::PyTuple};
+use serde::{Deserialize, Serialize};
 use tantivy as tv;

 /// Tantivy schema.
 ///
 /// The schema is very strict. To build the schema the `SchemaBuilder` class is
 /// provided.
-#[pyclass(frozen)]
-#[derive(PartialEq)]
+#[pyclass(frozen, module = "tantivy")]
+#[derive(Deserialize, PartialEq, Serialize)]
 pub(crate) struct Schema {
     pub(crate) inner: tv::schema::Schema,
 }
@@ -25,4 +27,24 @@ impl Schema {
             _ => py.NotImplemented(),
         }
     }
+
+    #[staticmethod]
+    fn _internal_from_pythonized(serialized: &PyAny) -> PyResult<Self> {
+        pythonize::depythonize(serialized).map_err(to_pyerr)
+    }
+
+    fn __reduce__<'a>(
+        slf: PyRef<'a, Self>,
+        py: Python<'a>,
+    ) -> PyResult<&'a PyTuple> {
+        let serialized = pythonize::pythonize(py, &*slf).map_err(to_pyerr)?;
+
+        Ok(PyTuple::new(
+            py,
+            [
+                slf.into_py(py).getattr(py, "_internal_from_pythonized")?,
+                PyTuple::new(py, [serialized]).to_object(py),
+            ],
+        ))
+    }
 }
diff --git a/src/searcher.rs b/src/searcher.rs
index ae37fa5..d76d984 100644
--- a/src/searcher.rs
+++ b/src/searcher.rs
@@ -2,6 +2,7 @@

 use crate::{document::Document, query::Query, to_pyerr};
 use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
+use serde::{Deserialize, Serialize};
 use tantivy as tv;
 use tantivy::collector::{Count, MultiCollector, TopDocs};

@@ -13,9 +14,11 @@ pub(crate) struct Searcher {
     pub(crate) inner: tv::Searcher,
 }

-#[derive(Clone, PartialEq)]
+#[derive(Clone, Deserialize, FromPyObject, PartialEq, Serialize)]
 enum Fruit {
+    #[pyo3(transparent)]
     Score(f32),
+    #[pyo3(transparent)]
     Order(u64),
 }

@@ -37,8 +40,8 @@ impl ToPyObject for Fruit {
     }
 }

-#[pyclass(frozen)]
-#[derive(Clone, PartialEq)]
+#[pyclass(frozen, module = "tantivy")]
+#[derive(Clone, Default, Deserialize, PartialEq, Serialize)]
 /// Object holding a results successful search.
 pub(crate) struct SearchResult {
     hits: Vec<(Fruit, DocAddress)>,
@@ -50,6 +53,19 @@ pub(crate) struct SearchResult {

 #[pymethods]
 impl SearchResult {
+    #[new]
+    fn new(
+        py: Python,
+        hits: Vec<(PyObject, DocAddress)>,
+        count: Option<usize>,
+    ) -> PyResult<Self> {
+        let hits = hits
+            .iter()
+            .map(|(f, d)| Ok((f.extract(py)?, d.clone())))
+            .collect::<PyResult<Vec<_>>>()?;
+        Ok(Self { hits, count })
+    }
+
     fn __repr__(&self) -> PyResult<String> {
         if let Some(count) = self.count {
             Ok(format!(
@@ -74,6 +90,13 @@ impl SearchResult {
         }
     }

+    fn __getnewargs__(
+        &self,
+        py: Python,
+    ) -> PyResult<(Vec<(PyObject, DocAddress)>, Option<usize>)> {
+        Ok((self.hits(py)?, self.count))
+    }
+
     #[getter]
     /// The list of tuples that contains the scores and DocAddress of the
     /// search results.
@@ -214,8 +237,8 @@ impl Searcher {
 /// It consists in an id identifying its segment, and its segment-local DocId.
 /// The id used for the segment is actually an ordinal in the list of segment
 /// hold by a Searcher.
-#[pyclass(frozen)]
-#[derive(Clone, Debug, PartialEq)]
+#[pyclass(frozen, module = "tantivy")]
+#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
 pub(crate) struct DocAddress {
     pub(crate) segment_ord: tv::SegmentOrdinal,
     pub(crate) doc: tv::DocId,
@@ -223,6 +246,11 @@ pub(crate) struct DocAddress {

 #[pymethods]
 impl DocAddress {
+    #[new]
+    fn new(segment_ord: tv::SegmentOrdinal, doc: tv::DocId) -> Self {
+        DocAddress { segment_ord, doc }
+    }
+
     /// The segment ordinal is an id identifying the segment hosting the
     /// document. It is only meaningful, in the context of a searcher.
     #[getter]
@@ -248,6 +276,10 @@ impl DocAddress {
             _ => py.NotImplemented(),
         }
     }
+
+    fn __getnewargs__(&self) -> PyResult<(tv::SegmentOrdinal, tv::DocId)> {
+        Ok((self.segment_ord, self.doc))
+    }
 }

 impl From<&tv::DocAddress> for DocAddress {
diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py
index efa23d5..a15fa30 100644
--- a/tests/tantivy_test.py
+++ b/tests/tantivy_test.py
@@ -1,6 +1,9 @@
 from io import BytesIO
+
 import copy
+import datetime
 import tantivy
+import pickle
 import pytest

 from tantivy import Document, Index, SchemaBuilder
@@ -476,6 +479,15 @@ class TestClass(object):
         assert eng_result1 != esp_result
         assert eng_result2 != esp_result

+    def test_search_result_pickle(self, ram_index):
+        index = ram_index
+        query = index.parse_query("sea whale", ["title", "body"])
+
+        orig = index.searcher().search(query, 10)
+        pickled = pickle.loads(pickle.dumps(orig))
+
+        assert orig == pickled
+

 class TestUpdateClass(object):
     def test_delete_update(self, ram_index):
@@ -544,7 +556,10 @@ class TestFromDiskClass(object):
 class TestSearcher(object):
     def test_searcher_repr(self, ram_index, ram_index_numeric_fields):
         assert repr(ram_index.searcher()) == "Searcher(num_docs=3, num_segments=1)"
-        assert repr(ram_index_numeric_fields.searcher()) == "Searcher(num_docs=2, num_segments=1)"
+        assert (
+            repr(ram_index_numeric_fields.searcher())
+            == "Searcher(num_docs=2, num_segments=1)"
+        )


 class TestDocument(object):
@@ -557,8 +572,6 @@ class TestDocument(object):
         assert doc.to_dict() == {"name": ["Bill"], "reference": [1, 2]}

     def test_document_with_date(self):
-        import datetime
-
         date = datetime.datetime(2019, 8, 12, 13, 0, 0)
         doc = tantivy.Document(name="Bill", date=date)
         assert doc["date"][0] == date
@@ -607,6 +620,23 @@ class TestDocument(object):
         assert doc1 == doc3
         assert doc2 == doc3

+    def test_document_pickle(self):
+        orig = Document()
+        orig.add_unsigned("unsigned", 1)
+        orig.add_integer("integer", 5)
+        orig.add_float("float", 1.0)
+        orig.add_date("birth", datetime.datetime(2019, 8, 12, 13, 0, 5))
+        orig.add_text("title", "hello world!")
+        orig.add_json("json", '{"a": 1, "b": 2}')
+        orig.add_bytes("bytes", b"abc")
+
+        facet = tantivy.Facet.from_string("/europe/france")
+        orig.add_facet("facet", facet)
+
+        pickled = pickle.loads(pickle.dumps(orig))
+
+        assert orig == pickled
+

 class TestJsonField:
     def test_query_from_json_field(self):
@@ -722,3 +752,35 @@ def test_facet_eq():
     assert facet1 == facet2
     assert facet1 != facet3
     assert facet2 != facet3
+
+
+def test_schema_pickle():
+    orig = (
+        SchemaBuilder()
+        .add_integer_field("id", stored=True, indexed=True)
+        .add_unsigned_field("unsigned")
+        .add_float_field("rating", stored=True, indexed=True)
+        .add_text_field("body", stored=True)
+        .add_date_field("date")
+        .add_json_field("json")
+        .add_bytes_field("bytes")
+        .build()
+    )
+
+    pickled = pickle.loads(pickle.dumps(orig))
+
+    assert orig == pickled
+
+
+def test_facet_pickle():
+    orig = tantivy.Facet.from_string("/europe/france")
+    pickled = pickle.loads(pickle.dumps(orig))
+
+    assert orig == pickled
+
+
+def test_doc_address_pickle():
+    orig = tantivy.DocAddress(42, 123)
+    pickled = pickle.loads(pickle.dumps(orig))
+
+    assert orig == pickled