From 8b33e00c5895bf87f1b21e8767efad8d54b95d36 Mon Sep 17 00:00:00 2001 From: Chris Tam Date: Fri, 4 Aug 2023 03:23:31 -0400 Subject: [PATCH] Support copy, deepcopy, eq on types (#99) --- src/document.rs | 24 ++++++++++++++++++- src/facet.rs | 19 ++++++++++++--- src/query.rs | 2 +- src/schema.rs | 20 +++++++++++++--- src/searcher.rs | 37 +++++++++++++++++++++++++---- tests/tantivy_test.py | 54 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 143 insertions(+), 13 deletions(-) diff --git a/src/document.rs b/src/document.rs index 2a5a41a..529567a 100644 --- a/src/document.rs +++ b/src/document.rs @@ -3,6 +3,7 @@ use itertools::Itertools; use pyo3::{ + basic::CompareOp, prelude::*, types::{ PyAny, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess, PyTuple, @@ -148,7 +149,7 @@ fn value_to_string(value: &Value) -> String { /// schema, /// ) #[pyclass] -#[derive(Default)] +#[derive(Clone, Default, PartialEq)] pub(crate) struct Document { pub(crate) field_values: BTreeMap>, } @@ -552,6 +553,27 @@ impl Document { fn __repr__(&self) -> PyResult { Ok(format!("{self:?}")) } + + fn __copy__(&self) -> Self { + self.clone() + } + + fn __deepcopy__(&self, _memo: &PyDict) -> Self { + self.clone() + } + + fn __richcmp__( + &self, + other: &Self, + op: CompareOp, + py: Python<'_>, + ) -> PyObject { + match op { + CompareOp::Eq => (self == other).into_py(py), + CompareOp::Ne => (self != other).into_py(py), + _ => py.NotImplemented(), + } + } } impl Document { diff --git a/src/facet.rs b/src/facet.rs index b02cfb5..a624e24 100644 --- a/src/facet.rs +++ b/src/facet.rs @@ -1,4 +1,4 @@ -use pyo3::{prelude::*, types::PyType}; +use pyo3::{basic::CompareOp, prelude::*, types::PyType}; use tantivy::schema; /// A Facet represent a point in a given hierarchy. @@ -10,8 +10,8 @@ use tantivy::schema; /// implicitely imply that a document belonging to a facet also belongs to the /// ancestor of its facet. In the example above, /electronics/tv_and_video/ /// and /electronics. -#[pyclass] -#[derive(Clone)] +#[pyclass(frozen)] +#[derive(Clone, PartialEq)] pub(crate) struct Facet { pub(crate) inner: schema::Facet, } @@ -67,4 +67,17 @@ impl Facet { fn __repr__(&self) -> PyResult { Ok(format!("Facet({})", self.to_path_str())) } + + fn __richcmp__( + &self, + other: &Self, + op: CompareOp, + py: Python<'_>, + ) -> PyObject { + match op { + CompareOp::Eq => (self == other).into_py(py), + CompareOp::Ne => (self != other).into_py(py), + _ => py.NotImplemented(), + } + } } diff --git a/src/query.rs b/src/query.rs index 40e4382..ef841a0 100644 --- a/src/query.rs +++ b/src/query.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use tantivy as tv; /// Tantivy's Query -#[pyclass] +#[pyclass(frozen)] pub(crate) struct Query { pub(crate) inner: Box, } diff --git a/src/schema.rs b/src/schema.rs index 00d0c53..61cf273 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,14 +1,28 @@ -use pyo3::prelude::*; +use pyo3::{basic::CompareOp, prelude::*}; use tantivy as tv; /// Tantivy schema. /// /// The schema is very strict. To build the schema the `SchemaBuilder` class is /// provided. -#[pyclass] +#[pyclass(frozen)] +#[derive(PartialEq)] pub(crate) struct Schema { pub(crate) inner: tv::schema::Schema, } #[pymethods] -impl Schema {} +impl Schema { + fn __richcmp__( + &self, + other: &Self, + op: CompareOp, + py: Python<'_>, + ) -> PyObject { + match op { + CompareOp::Eq => (self == other).into_py(py), + CompareOp::Ne => (self != other).into_py(py), + _ => py.NotImplemented(), + } + } +} diff --git a/src/searcher.rs b/src/searcher.rs index 7b82964..ae37fa5 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -1,7 +1,7 @@ #![allow(clippy::new_ret_no_self)] use crate::{document::Document, query::Query, to_pyerr}; -use pyo3::{exceptions::PyValueError, prelude::*}; +use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*}; use tantivy as tv; use tantivy::collector::{Count, MultiCollector, TopDocs}; @@ -13,7 +13,7 @@ pub(crate) struct Searcher { pub(crate) inner: tv::Searcher, } -#[derive(Clone)] +#[derive(Clone, PartialEq)] enum Fruit { Score(f32), Order(u64), @@ -37,7 +37,8 @@ impl ToPyObject for Fruit { } } -#[pyclass] +#[pyclass(frozen)] +#[derive(Clone, PartialEq)] /// Object holding a results successful search. pub(crate) struct SearchResult { hits: Vec<(Fruit, DocAddress)>, @@ -60,6 +61,19 @@ impl SearchResult { } } + fn __richcmp__( + &self, + other: &Self, + op: CompareOp, + py: Python<'_>, + ) -> PyObject { + match op { + CompareOp::Eq => (self == other).into_py(py), + CompareOp::Ne => (self != other).into_py(py), + _ => py.NotImplemented(), + } + } + #[getter] /// The list of tuples that contains the scores and DocAddress of the /// search results. @@ -200,8 +214,8 @@ impl Searcher { /// It consists in an id identifying its segment, and its segment-local DocId. /// The id used for the segment is actually an ordinal in the list of segment /// hold by a Searcher. -#[pyclass] -#[derive(Clone, Debug)] +#[pyclass(frozen)] +#[derive(Clone, Debug, PartialEq)] pub(crate) struct DocAddress { pub(crate) segment_ord: tv::SegmentOrdinal, pub(crate) doc: tv::DocId, @@ -221,6 +235,19 @@ impl DocAddress { fn doc(&self) -> u32 { self.doc } + + fn __richcmp__( + &self, + other: &Self, + op: CompareOp, + py: Python<'_>, + ) -> PyObject { + match op { + CompareOp::Eq => (self == other).into_py(py), + CompareOp::Ne => (self != other).into_py(py), + _ => py.NotImplemented(), + } + } } impl From<&tv::DocAddress> for DocAddress { diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 799396b..c18aaae 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -1,4 +1,5 @@ from io import BytesIO +import copy import tantivy import pytest @@ -457,6 +458,21 @@ class TestClass(object): schema, ) + def test_search_result_eq(self, ram_index, spanish_index): + eng_index = ram_index + eng_query = eng_index.parse_query("sea whale", ["title", "body"]) + + esp_index = spanish_index + esp_query = esp_index.parse_query("vieja", ["title", "body"]) + + eng_result1 = eng_index.searcher().search(eng_query, 10) + eng_result2 = eng_index.searcher().search(eng_query, 10) + esp_result = esp_index.searcher().search(esp_query, 10) + + assert eng_result1 == eng_result2 + assert eng_result1 != esp_result + assert eng_result2 != esp_result + class TestUpdateClass(object): def test_delete_update(self, ram_index): @@ -570,6 +586,24 @@ class TestDocument(object): with pytest.raises(ValueError): tantivy.Document(name={}) + def test_document_eq(self): + doc1 = tantivy.Document(name="Bill", reference=[1, 2]) + doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]}) + doc3 = tantivy.Document(name="Bob", reference=[3, 4]) + + assert doc1 == doc2 + assert doc1 != doc3 + assert doc2 != doc3 + + def test_document_copy(self): + doc1 = tantivy.Document(name="Bill", reference=[1, 2]) + doc2 = copy.copy(doc1) + doc3 = copy.deepcopy(doc2) + + assert doc1 == doc2 + assert doc1 == doc3 + assert doc2 == doc3 + class TestJsonField: def test_query_from_json_field(self): @@ -665,3 +699,23 @@ def test_bytes(bytes_kwarg, bytes_payload): writer.add_document(doc) writer.commit() index.reload() + + +def test_schema_eq(): + schema1 = schema() + schema2 = schema() + schema3 = schema_numeric_fields() + + assert schema1 == schema2 + assert schema1 != schema3 + assert schema2 != schema3 + + +def test_facet_eq(): + facet1 = tantivy.Facet.from_string("/europe/france") + facet2 = tantivy.Facet.from_string("/europe/france") + facet3 = tantivy.Facet.from_string("/europe/germany") + + assert facet1 == facet2 + assert facet1 != facet3 + assert facet2 != facet3