diff --git a/Cargo.toml b/Cargo.toml index 330d748..84a1bea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy" -version = "0.16.0" +version = "0.17.0" readme = "README.md" authors = ["Damir Jelić "] edition = "2018" @@ -11,14 +11,15 @@ name = "tantivy" crate-type = ["cdylib"] [build-dependencies] -pyo3-build-config = "0.15.1" +pyo3-build-config = "0.16.3" [dependencies] chrono = "0.4.19" -tantivy = "0.16.1" -itertools = "0.10.0" -futures = "0.3.5" +tantivy = "0.17" +itertools = "0.10.3" +futures = "0.3.21" +serde_json = "1.0.64" [dependencies.pyo3] -version = "0.15.1" +version = "0.16.3" features = ["extension-module"] diff --git a/src/document.rs b/src/document.rs index 5088f51..acb79f5 100644 --- a/src/document.rs +++ b/src/document.rs @@ -14,10 +14,37 @@ use chrono::{offset::TimeZone, Datelike, Timelike, Utc}; use tantivy as tv; use crate::{facet::Facet, to_pyerr}; -use pyo3::{PyMappingProtocol, PyObjectProtocol}; -use std::{collections::BTreeMap, fmt}; +use serde_json::Value as JsonValue; +use std::{ + collections::{BTreeMap, HashMap}, + fmt, +}; use tantivy::schema::Value; +fn value_to_object(val: &JsonValue, py: Python<'_>) -> PyObject { + match val { + JsonValue::Null => py.None(), + JsonValue::Bool(b) => b.to_object(py), + JsonValue::Number(n) => match n { + n if n.is_i64() => n.as_i64().to_object(py), + n if n.is_u64() => n.as_u64().to_object(py), + n if n.is_f64() => n.as_f64().to_object(py), + _ => panic!("number too large"), + }, + JsonValue::String(s) => s.to_object(py), + JsonValue::Array(v) => { + let inner: Vec<_> = + v.iter().map(|x| value_to_object(x, py)).collect(); + inner.to_object(py) + } + JsonValue::Object(m) => { + let inner: HashMap<_, _> = + m.iter().map(|(k, v)| (k, value_to_object(v, py))).collect(); + inner.to_object(py) + } + } +} + fn value_to_py(py: Python, value: &Value) -> PyResult { Ok(match value { Value::Str(text) => text.into_py(py), @@ -42,6 +69,13 @@ fn value_to_py(py: Python, value: &Value) 
-> PyResult { )? .into_py(py), Value::Facet(f) => Facet { inner: f.clone() }.into_py(py), + Value::JsonObject(json_object) => { + let inner: HashMap<_, _> = json_object + .iter() + .map(|(k, v)| (k, value_to_object(&v, py))) + .collect(); + inner.to_object(py) + } }) } @@ -58,6 +92,9 @@ fn value_to_string(value: &Value) -> String { // TODO implement me unimplemented!(); } + Value::JsonObject(json_object) => { + serde_json::to_string(&json_object).unwrap() + } } } @@ -293,6 +330,17 @@ impl Document { add_value(self, field_name, bytes); } + /// Add a json object to the document. + /// + /// Args: + /// field_name (str): The field for which we are adding the json object. + /// json (str): The json object that will be added to the document. + fn add_json(&mut self, field_name: String, json: &str) { + let json_object: serde_json::Value = + serde_json::from_str(json).unwrap(); + add_value(self, field_name, json_object); + } + /// Returns the number of added fields that have been added to the document #[getter] fn num_fields(&self) -> usize { @@ -337,6 +385,16 @@ impl Document { .map(|value| value_to_py(py, value)) .collect::>>() } + + fn __getitem__(&self, field_name: &str) -> PyResult> { + let gil = Python::acquire_gil(); + let py = gil.python(); + self.get_all(py, field_name) + } + + fn __repr__(&self) -> PyResult { + Ok(format!("{:?}", self)) + } } impl Document { @@ -350,19 +408,3 @@ impl Document { .flat_map(|values| values.iter()) } } - -#[pyproto] -impl PyMappingProtocol for Document { - fn __getitem__(&self, field_name: &str) -> PyResult> { - let gil = Python::acquire_gil(); - let py = gil.python(); - self.get_all(py, field_name) - } -} - -#[pyproto] -impl PyObjectProtocol for Document { - fn __repr__(&self) -> PyResult { - Ok(format!("{:?}", self)) - } -} diff --git a/src/facet.rs b/src/facet.rs index 72fdd4f..b02cfb5 100644 --- a/src/facet.rs +++ b/src/facet.rs @@ -1,4 +1,4 @@ -use pyo3::{basic::PyObjectProtocol, prelude::*, types::PyType}; +use pyo3::{prelude::*, 
types::PyType}; use tantivy::schema; /// A Facet represent a point in a given hierarchy. @@ -63,10 +63,7 @@ impl Facet { fn to_path_str(&self) -> String { self.inner.to_string() } -} -#[pyproto] -impl PyObjectProtocol for Facet { fn __repr__(&self) -> PyResult { Ok(format!("Facet({})", self.to_path_str())) } diff --git a/src/index.rs b/src/index.rs index 7a93082..a24647b 100644 --- a/src/index.rs +++ b/src/index.rs @@ -41,7 +41,7 @@ impl IndexWriter { pub fn add_document(&mut self, doc: &Document) -> PyResult { let named_doc = NamedFieldDocument(doc.field_values.clone()); let doc = self.schema.convert_named_doc(named_doc).map_err(to_pyerr)?; - Ok(self.inner_index_writer.add_document(doc)) + self.inner_index_writer.add_document(doc).map_err(to_pyerr) } /// Helper for the `add_document` method, but passing a json string. @@ -55,7 +55,7 @@ impl IndexWriter { pub fn add_json(&mut self, json: &str) -> PyResult { let doc = self.schema.parse_document(json).map_err(to_pyerr)?; let opstamp = self.inner_index_writer.add_document(doc); - Ok(opstamp) + opstamp.map_err(to_pyerr) } /// Commits all of the pending changes @@ -134,6 +134,12 @@ impl IndexWriter { field_name ))) } + Value::JsonObject(_) => { + return Err(exceptions::PyValueError::new_err(format!( + "Field `{}` is json object type not deletable.", + field_name + ))) + } }; Ok(self.inner_index_writer.delete_term(term)) } @@ -281,7 +287,7 @@ impl Index { #[staticmethod] fn exists(path: &str) -> PyResult { let directory = MmapDirectory::open(path).map_err(to_pyerr)?; - Ok(tv::Index::exists(&directory).unwrap()) + tv::Index::exists(&directory).map_err(to_pyerr) } /// The schema of the current index. @@ -304,7 +310,7 @@ impl Index { /// /// Args: /// query: the query, following the tantivy query language. - /// default_fields (List[Field]): A list of fields used to search if no + /// default_field_names (List[Field]): A list of fields used to search if no /// field is specified in the query. 
/// #[args(reload_policy = "RELOAD_POLICY")] diff --git a/src/query.rs b/src/query.rs index e520953..40e4382 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,4 +1,4 @@ -use pyo3::{prelude::*, PyObjectProtocol}; +use pyo3::prelude::*; use tantivy as tv; /// Tantivy's Query @@ -13,8 +13,8 @@ impl Query { } } -#[pyproto] -impl PyObjectProtocol for Query { +#[pymethods] +impl Query { fn __repr__(&self) -> PyResult { Ok(format!("Query({:?})", self.get())) } diff --git a/src/schemabuilder.rs b/src/schemabuilder.rs index e3735fb..30cbd29 100644 --- a/src/schemabuilder.rs +++ b/src/schemabuilder.rs @@ -73,26 +73,11 @@ impl SchemaBuilder { index_option: &str, ) -> PyResult { let builder = &mut self.builder; - let index_option = match index_option { - "position" => schema::IndexRecordOption::WithFreqsAndPositions, - "freq" => schema::IndexRecordOption::WithFreqs, - "basic" => schema::IndexRecordOption::Basic, - _ => return Err(exceptions::PyValueError::new_err( - "Invalid index option, valid choices are: 'basic', 'freq' and 'position'" - )) - }; - - let indexing = schema::TextFieldIndexing::default() - .set_tokenizer(tokenizer_name) - .set_index_option(index_option); - - let options = - schema::TextOptions::default().set_indexing_options(indexing); - let options = if stored { - options.set_stored() - } else { - options - }; + let options = SchemaBuilder::build_text_option( + stored, + tokenizer_name, + index_option, + )?; if let Some(builder) = builder.write().unwrap().as_mut() { builder.add_text_field(name, options); @@ -230,6 +215,55 @@ impl SchemaBuilder { Ok(self.clone()) } + /// Add a new json field to the schema. + /// + /// Args: + /// name (str): the name of the field. + /// stored (bool, optional): If true sets the field as stored, the + /// content of the field can be later restored from a Searcher. + /// Defaults to False. + /// tokenizer_name (str, optional): The name of the tokenizer that + /// should be used to process the field. 
Defaults to 'default' + /// index_option (str, optional): Sets which information should be + /// indexed with the tokens. Can be one of 'position', 'freq' or + /// 'basic'. Defaults to 'position'. The 'basic' index_option + /// records only the document ID, the 'freq' option records the + /// document id and the term frequency, while the 'position' option + /// records the document id, term frequency and the positions of + /// the term occurrences in the document. + /// + /// Returns the associated field handle. + /// Raises a ValueError if there was an error with the field creation. + #[args( + stored = false, + tokenizer_name = "TOKENIZER", + index_option = "RECORD" + )] + fn add_json_field( + &mut self, + name: &str, + stored: bool, + tokenizer_name: &str, + index_option: &str, + ) -> PyResult { + let builder = &mut self.builder; + let options = SchemaBuilder::build_text_option( + stored, + tokenizer_name, + index_option, + )?; + + if let Some(builder) = builder.write().unwrap().as_mut() { + builder.add_json_field(name, options); + } else { + return Err(exceptions::PyValueError::new_err( + "Schema builder object isn't valid anymore.", + )); + } + + Ok(self.clone()) + } + /// Add a Facet field to the schema. /// Args: /// name (str): The name of the field. 
@@ -289,8 +323,8 @@ impl SchemaBuilder { stored: bool, indexed: bool, fast: Option<&str>, - ) -> PyResult { - let opts = schema::IntOptions::default(); + ) -> PyResult { + let opts = schema::NumericOptions::default(); let opts = if stored { opts.set_stored() } else { opts }; let opts = if indexed { opts.set_indexed() } else { opts }; @@ -317,4 +351,33 @@ impl SchemaBuilder { Ok(opts) } + + fn build_text_option( + stored: bool, + tokenizer_name: &str, + index_option: &str, + ) -> PyResult { + let index_option = match index_option { + "position" => schema::IndexRecordOption::WithFreqsAndPositions, + "freq" => schema::IndexRecordOption::WithFreqs, + "basic" => schema::IndexRecordOption::Basic, + _ => return Err(exceptions::PyValueError::new_err( + "Invalid index option, valid choices are: 'basic', 'freq' and 'position'" + )) + }; + + let indexing = schema::TextFieldIndexing::default() + .set_tokenizer(tokenizer_name) + .set_index_option(index_option); + + let options = + schema::TextOptions::default().set_indexing_options(indexing); + let options = if stored { + options.set_stored() + } else { + options + }; + + Ok(options) + } } diff --git a/src/searcher.rs b/src/searcher.rs index 33e33cb..c2b6796 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -1,7 +1,7 @@ #![allow(clippy::new_ret_no_self)] use crate::{document::Document, get_field, query::Query, to_pyerr}; -use pyo3::{exceptions::PyValueError, prelude::*, PyObjectProtocol}; +use pyo3::{exceptions::PyValueError, prelude::*}; use tantivy as tv; use tantivy::collector::{Count, MultiCollector, TopDocs}; @@ -47,8 +47,8 @@ pub(crate) struct SearchResult { count: Option, } -#[pyproto] -impl PyObjectProtocol for SearchResult { +#[pymethods] +impl SearchResult { fn __repr__(&self) -> PyResult { if let Some(count) = self.count { Ok(format!( @@ -59,10 +59,7 @@ impl PyObjectProtocol for SearchResult { Ok(format!("SearchResult(hits: {:?})", self.hits)) } } -} -#[pymethods] -impl SearchResult { #[getter] /// The list of 
tuples that contains the scores and DocAddress of the /// search results. @@ -185,6 +182,14 @@ impl Searcher { field_values: named_doc.0, }) } + + fn __repr__(&self) -> PyResult { + Ok(format!( + "Searcher(num_docs={}, num_segments={})", + self.inner.num_docs(), + self.inner.segment_readers().len() + )) + } } /// DocAddress contains all the necessary information to identify a document @@ -233,14 +238,3 @@ impl Into for &DocAddress { } } } - -#[pyproto] -impl PyObjectProtocol for Searcher { - fn __repr__(&self) -> PyResult { - Ok(format!( - "Searcher(num_docs={}, num_segments={})", - self.inner.num_docs(), - self.inner.segment_readers().len() - )) - } -} diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 81c2f9c..0d6d898 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -5,7 +5,13 @@ from tantivy import Document, Index, SchemaBuilder def schema(): - return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").build() + return ( + SchemaBuilder() + .add_text_field("title", stored=True) + .add_text_field("body") + .build() + ) + def create_index(dir=None): # assume all tests will use the same documents for now @@ -99,7 +105,9 @@ class TestClass(object): def test_and_query(self, ram_index): index = ram_index - query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"]) + query = index.parse_query( + "title:men AND body:summer", default_field_names=["title", "body"] + ) # look for an intersection of documents searcher = index.searcher() result = searcher.search(query, 10) @@ -114,15 +122,13 @@ class TestClass(object): def test_and_query_parser_default_fields(self, ram_index): query = ram_index.parse_query("winter", default_field_names=["title"]) - assert repr(query) == """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))""" + assert repr(query) == """Query(TermQuery(Term(type=Str, field=0, "winter")))""" def test_and_query_parser_default_fields_undefined(self, 
ram_index): query = ram_index.parse_query("winter") assert ( - repr(query) == "Query(BooleanQuery { subqueries: [" - "(Should, TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114]))), " - "(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114])))] " - "})" + repr(query) + == """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(type=Str, field=0, "winter"))), (Should, TermQuery(Term(type=Str, field=1, "winter")))] })""" ) def test_query_errors(self, ram_index): @@ -132,9 +138,11 @@ class TestClass(object): index.parse_query("bod:men", ["title", "body"]) def test_order_by_search(self): - schema = (SchemaBuilder() + schema = ( + SchemaBuilder() .add_unsigned_field("order", fast="single") - .add_text_field("title", stored=True).build() + .add_text_field("title", stored=True) + .build() ) index = Index(schema) @@ -155,7 +163,6 @@ class TestClass(object): doc.add_unsigned("order", 1) doc.add_text("title", "Another test title") - writer.add_document(doc) writer.commit() @@ -163,7 +170,6 @@ class TestClass(object): query = index.parse_query("test") - searcher = index.searcher() result = searcher.search(query, 10, offset=2, order_by_field="order") @@ -187,9 +193,11 @@ class TestClass(object): assert searched_doc["title"] == ["Test title"] def test_order_by_search_without_fast_field(self): - schema = (SchemaBuilder() + schema = ( + SchemaBuilder() .add_unsigned_field("order") - .add_text_field("title", stored=True).build() + .add_text_field("title", stored=True) + .build() ) index = Index(schema) @@ -316,3 +324,72 @@ class TestDocument(object): def test_document_error(self): with pytest.raises(ValueError): tantivy.Document(name={}) + + +class TestJsonField: + def test_query_from_json_field(self): + schema = ( + SchemaBuilder() + .add_json_field( + "attributes", + stored=True, + tokenizer_name="default", + index_option="position", + ) + .build() + ) + + index = Index(schema) + + writer = index.writer() + + doc = Document() + doc.add_json( + 
"attributes", + """{ + "order":1.1, + "target": "submit-button", + "cart": {"product_id": 103}, + "description": "the best vacuum cleaner ever" + }""", + ) + + writer.add_document(doc) + + doc = Document() + doc.add_json( + "attributes", + """{ + "order":1.2, + "target": "submit-button", + "cart": {"product_id": 133}, + "description": "das keyboard" + }""", + ) + + writer.add_document(doc) + + writer.commit() + index.reload() + + query = index.parse_query("target:submit-button", ["attributes"]) + result = index.searcher().search(query, 2) + assert len(result.hits) == 2 + + query = index.parse_query("target:submit", ["attributes"]) + result = index.searcher().search(query, 2) + assert len(result.hits) == 2 + + query = index.parse_query("order:1.1", ["attributes"]) + result = index.searcher().search(query, 2) + assert len(result.hits) == 1 + + # query = index.parse_query_for_attributes("cart.product_id:103") + # result = index.searcher().search(query, 1) + # assert len(result.hits) == 1 + + # query = index.parse_query_for_attributes( + # "target:submit-button AND cart.product_id:133" + # ) + # result = index.searcher().search(query, 2) + # assert len(result.hits) == 1