Support copy, deepcopy, eq on types (#99)

master
Chris Tam 2023-08-04 03:23:31 -04:00 committed by GitHub
parent 0032362e97
commit 8b33e00c58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 143 additions and 13 deletions

View File

@ -3,6 +3,7 @@
use itertools::Itertools;
use pyo3::{
basic::CompareOp,
prelude::*,
types::{
PyAny, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess, PyTuple,
@ -148,7 +149,7 @@ fn value_to_string(value: &Value) -> String {
/// schema,
/// )
#[pyclass]
#[derive(Default)]
#[derive(Clone, Default, PartialEq)]
pub(crate) struct Document {
pub(crate) field_values: BTreeMap<String, Vec<tv::schema::Value>>,
}
@ -552,6 +553,27 @@ impl Document {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("{self:?}"))
}
fn __copy__(&self) -> Self {
self.clone()
}
fn __deepcopy__(&self, _memo: &PyDict) -> Self {
self.clone()
}
fn __richcmp__(
&self,
other: &Self,
op: CompareOp,
py: Python<'_>,
) -> PyObject {
match op {
CompareOp::Eq => (self == other).into_py(py),
CompareOp::Ne => (self != other).into_py(py),
_ => py.NotImplemented(),
}
}
}
impl Document {

View File

@ -1,4 +1,4 @@
use pyo3::{prelude::*, types::PyType};
use pyo3::{basic::CompareOp, prelude::*, types::PyType};
use tantivy::schema;
/// A Facet represent a point in a given hierarchy.
@ -10,8 +10,8 @@ use tantivy::schema;
/// implicitely imply that a document belonging to a facet also belongs to the
/// ancestor of its facet. In the example above, /electronics/tv_and_video/
/// and /electronics.
#[pyclass]
#[derive(Clone)]
#[pyclass(frozen)]
#[derive(Clone, PartialEq)]
pub(crate) struct Facet {
pub(crate) inner: schema::Facet,
}
@ -67,4 +67,17 @@ impl Facet {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("Facet({})", self.to_path_str()))
}
fn __richcmp__(
&self,
other: &Self,
op: CompareOp,
py: Python<'_>,
) -> PyObject {
match op {
CompareOp::Eq => (self == other).into_py(py),
CompareOp::Ne => (self != other).into_py(py),
_ => py.NotImplemented(),
}
}
}

View File

@ -2,7 +2,7 @@ use pyo3::prelude::*;
use tantivy as tv;
/// Tantivy's Query
#[pyclass]
#[pyclass(frozen)]
pub(crate) struct Query {
pub(crate) inner: Box<dyn tv::query::Query>,
}

View File

@ -1,14 +1,28 @@
use pyo3::prelude::*;
use pyo3::{basic::CompareOp, prelude::*};
use tantivy as tv;
/// Tantivy schema.
///
/// The schema is very strict. To build the schema the `SchemaBuilder` class is
/// provided.
#[pyclass]
#[pyclass(frozen)]
#[derive(PartialEq)]
pub(crate) struct Schema {
pub(crate) inner: tv::schema::Schema,
}
#[pymethods]
impl Schema {}
impl Schema {
fn __richcmp__(
&self,
other: &Self,
op: CompareOp,
py: Python<'_>,
) -> PyObject {
match op {
CompareOp::Eq => (self == other).into_py(py),
CompareOp::Ne => (self != other).into_py(py),
_ => py.NotImplemented(),
}
}
}

View File

@ -1,7 +1,7 @@
#![allow(clippy::new_ret_no_self)]
use crate::{document::Document, query::Query, to_pyerr};
use pyo3::{exceptions::PyValueError, prelude::*};
use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};
@ -13,7 +13,7 @@ pub(crate) struct Searcher {
pub(crate) inner: tv::Searcher,
}
#[derive(Clone)]
#[derive(Clone, PartialEq)]
enum Fruit {
Score(f32),
Order(u64),
@ -37,7 +37,8 @@ impl ToPyObject for Fruit {
}
}
#[pyclass]
#[pyclass(frozen)]
#[derive(Clone, PartialEq)]
/// Object holding a results successful search.
pub(crate) struct SearchResult {
hits: Vec<(Fruit, DocAddress)>,
@ -60,6 +61,19 @@ impl SearchResult {
}
}
fn __richcmp__(
&self,
other: &Self,
op: CompareOp,
py: Python<'_>,
) -> PyObject {
match op {
CompareOp::Eq => (self == other).into_py(py),
CompareOp::Ne => (self != other).into_py(py),
_ => py.NotImplemented(),
}
}
#[getter]
/// The list of tuples that contains the scores and DocAddress of the
/// search results.
@ -200,8 +214,8 @@ impl Searcher {
/// It consists in an id identifying its segment, and its segment-local DocId.
/// The id used for the segment is actually an ordinal in the list of segment
/// hold by a Searcher.
#[pyclass]
#[derive(Clone, Debug)]
#[pyclass(frozen)]
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct DocAddress {
pub(crate) segment_ord: tv::SegmentOrdinal,
pub(crate) doc: tv::DocId,
@ -221,6 +235,19 @@ impl DocAddress {
fn doc(&self) -> u32 {
self.doc
}
fn __richcmp__(
&self,
other: &Self,
op: CompareOp,
py: Python<'_>,
) -> PyObject {
match op {
CompareOp::Eq => (self == other).into_py(py),
CompareOp::Ne => (self != other).into_py(py),
_ => py.NotImplemented(),
}
}
}
impl From<&tv::DocAddress> for DocAddress {

View File

@ -1,4 +1,5 @@
from io import BytesIO
import copy
import tantivy
import pytest
@ -457,6 +458,21 @@ class TestClass(object):
schema,
)
def test_search_result_eq(self, ram_index, spanish_index):
eng_index = ram_index
eng_query = eng_index.parse_query("sea whale", ["title", "body"])
esp_index = spanish_index
esp_query = esp_index.parse_query("vieja", ["title", "body"])
eng_result1 = eng_index.searcher().search(eng_query, 10)
eng_result2 = eng_index.searcher().search(eng_query, 10)
esp_result = esp_index.searcher().search(esp_query, 10)
assert eng_result1 == eng_result2
assert eng_result1 != esp_result
assert eng_result2 != esp_result
class TestUpdateClass(object):
def test_delete_update(self, ram_index):
@ -570,6 +586,24 @@ class TestDocument(object):
with pytest.raises(ValueError):
tantivy.Document(name={})
def test_document_eq(self):
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]})
doc3 = tantivy.Document(name="Bob", reference=[3, 4])
assert doc1 == doc2
assert doc1 != doc3
assert doc2 != doc3
def test_document_copy(self):
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
doc2 = copy.copy(doc1)
doc3 = copy.deepcopy(doc2)
assert doc1 == doc2
assert doc1 == doc3
assert doc2 == doc3
class TestJsonField:
def test_query_from_json_field(self):
@ -665,3 +699,23 @@ def test_bytes(bytes_kwarg, bytes_payload):
writer.add_document(doc)
writer.commit()
index.reload()
def test_schema_eq():
schema1 = schema()
schema2 = schema()
schema3 = schema_numeric_fields()
assert schema1 == schema2
assert schema1 != schema3
assert schema2 != schema3
def test_facet_eq():
facet1 = tantivy.Facet.from_string("/europe/france")
facet2 = tantivy.Facet.from_string("/europe/france")
facet3 = tantivy.Facet.from_string("/europe/germany")
assert facet1 == facet2
assert facet1 != facet3
assert facet2 != facet3