From 2f65cc65ff1e0f4546f3e4c788b46e5027a0113c Mon Sep 17 00:00:00 2001 From: Caleb Hattingh Date: Sun, 26 Mar 2023 15:03:31 +0200 Subject: [PATCH] Include check for bytes in extract_value, fixes #72 --- src/document.rs | 3 +++ tests/tantivy_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/document.rs b/src/document.rs index f737d9a..5fa3d46 100644 --- a/src/document.rs +++ b/src/document.rs @@ -194,6 +194,9 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult { if let Ok(facet) = any.extract::() { return Ok(Value::Facet(facet.inner)); } + if let Ok(b) = any.extract::>() { + return Ok(Value::Bytes(b)) + } Err(to_pyerr(format!("Value unsupported {any:?}"))) } diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 0157103..6a5246d 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -1,3 +1,4 @@ +from io import BytesIO import tantivy import pytest @@ -531,3 +532,27 @@ class TestJsonField: # ) # result = index.searcher().search(query, 2) # assert len(result.hits) == 1 + + +@pytest.mark.parametrize('bytes_kwarg', [True, False]) +@pytest.mark.parametrize('bytes_payload', [ + b"abc", + bytearray(b"abc"), + memoryview(b"abc"), + BytesIO(b"abc").read(), + BytesIO(b"abc").getbuffer(), +]) +def test_bytes(bytes_kwarg, bytes_payload): + schema = SchemaBuilder().add_bytes_field("embedding",).build() + index = Index(schema) + writer = index.writer() + + if bytes_kwarg: + doc = Document(id=1, embedding=bytes_payload) + else: + doc = Document(id=1) + doc.add_bytes("embedding", bytes_payload) + + writer.add_document(doc) + writer.commit() + index.reload()