From e95a4569d488059ce01f0cea4f1bc4a371f903a0 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Mon, 5 Feb 2024 12:01:26 +0100 Subject: [PATCH] Add field_boosts and fuzzy_fields optional parameters to Index::parse_query (#202) --- src/document.rs | 2 +- src/index.rs | 173 +++++++++++++++++++++++++----------------- src/parser_error.rs | 23 ++---- src/snippet.rs | 4 +- tests/tantivy_test.py | 14 ++++ 5 files changed, 127 insertions(+), 89 deletions(-) diff --git a/src/document.rs b/src/document.rs index b09e754..06899dc 100644 --- a/src/document.rs +++ b/src/document.rs @@ -125,7 +125,7 @@ pub(crate) fn extract_value_for_type( Value::JsonObject( any.downcast::() - .map(|dict| pythonize::depythonize(&dict)) + .map(|dict| pythonize::depythonize(dict)) .map_err(to_pyerr_for_type("Json", field_name, any))? .map_err(to_pyerr_for_type("Json", field_name, any))?, ) diff --git a/src/index.rs b/src/index.rs index 4636d24..55780db 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,5 +1,7 @@ #![allow(clippy::new_ret_no_self)] +use std::collections::HashMap; + use pyo3::{exceptions, prelude::*, types::PyAny}; use crate::{ @@ -358,44 +360,33 @@ impl Index { /// /// Args: /// query: the query, following the tantivy query language. + /// /// default_fields_names (List[Field]): A list of fields used to search if no /// field is specified in the query. /// - #[pyo3(signature = (query, default_field_names = None))] + /// field_boosts: A dictionary keyed on field names which provides default boosts + /// for the query constructed by this method. + /// + /// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one) + /// triples making queries constructed by this method fuzzy against the given fields + /// and using the given parameters. + /// `prefix` determines if terms which are prefixes of the given term match the query. + /// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term. + /// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance. + #[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))] pub fn parse_query( &self, query: &str, default_field_names: Option>, + field_boosts: HashMap, + fuzzy_fields: HashMap, ) -> PyResult { - let mut default_fields = vec![]; - let schema = self.index.schema(); - if let Some(default_field_names_vec) = default_field_names { - for default_field_name in &default_field_names_vec { - if let Ok(field) = schema.get_field(default_field_name) { - let field_entry = schema.get_field_entry(field); - if !field_entry.is_indexed() { - return Err(exceptions::PyValueError::new_err( - format!( - "Field `{default_field_name}` is not set as indexed in the schema." - ), - )); - } - default_fields.push(field); - } else { - return Err(exceptions::PyValueError::new_err(format!( - "Field `{default_field_name}` is not defined in the schema." - ))); - } - } - } else { - for (field, field_entry) in self.index.schema().fields() { - if field_entry.is_indexed() { - default_fields.push(field); - } - } - } - let parser = - tv::query::QueryParser::for_index(&self.index, default_fields); + let parser = self.prepare_query_parser( + default_field_names, + field_boosts, + fuzzy_fields, + )?; + let query = parser.parse_query(query).map_err(to_pyerr)?; Ok(Query { inner: query }) @@ -410,64 +401,106 @@ impl Index { /// /// Args: /// query: the query, following the tantivy query language. + /// /// default_fields_names (List[Field]): A list of fields used to search if no /// field is specified in the query. /// + /// field_boosts: A dictionary keyed on field names which provides default boosts + /// for the query constructed by this method. + /// + /// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one) + /// triples making queries constructed by this method fuzzy against the given fields + /// and using the given parameters. + /// `prefix` determines if terms which are prefixes of the given term match the query. + /// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term. + /// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance. + /// /// Returns a tuple containing the parsed query and a list of errors. /// /// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed. - #[pyo3(signature = (query, default_field_names = None))] + #[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))] pub fn parse_query_lenient( &self, query: &str, default_field_names: Option>, + field_boosts: HashMap, + fuzzy_fields: HashMap, + py: Python, ) -> PyResult<(Query, Vec)> { - let schema = self.index.schema(); + let parser = self.prepare_query_parser( + default_field_names, + field_boosts, + fuzzy_fields, + )?; - let default_fields = if let Some(default_field_names_vec) = - default_field_names - { - default_field_names_vec - .iter() - .map(|field_name| { - schema - .get_field(field_name) - .map_err(|_err| { - exceptions::PyValueError::new_err(format!( - "Field `{field_name}` is not defined in the schema." - )) - }) - .and_then(|field| { - schema.get_field_entry(field).is_indexed().then_some(field).ok_or( - exceptions::PyValueError::new_err( - format!( - "Field `{field_name}` is not set as indexed in the schema." - ), - )) - }) - }).collect::, _>>()? - } else { - self.index - .schema() - .fields() - .filter_map(|(f, fe)| fe.is_indexed().then_some(f)) - .collect::>() - }; - - let parser = - tv::query::QueryParser::for_index(&self.index, default_fields); let (query, errors) = parser.parse_query_lenient(query); + let errors = errors.into_iter().map(|err| err.into_py(py)).collect(); - Python::with_gil(|py| { - let errors = - errors.into_iter().map(|err| err.into_py(py)).collect(); - - Ok((Query { inner: query }, errors)) - }) + Ok((Query { inner: query }, errors)) } } impl Index { + fn prepare_query_parser( + &self, + default_field_names: Option>, + field_boosts: HashMap, + fuzzy_fields: HashMap, + ) -> PyResult { + let schema = self.index.schema(); + + let default_fields = if let Some(default_field_names) = + default_field_names + { + default_field_names.iter().map(|field_name| { + let field = schema.get_field(field_name).map_err(|_err| { + exceptions::PyValueError::new_err(format!( + "Field `{field_name}` is not defined in the schema." + )) + })?; + + let field_entry = schema.get_field_entry(field); + if !field_entry.is_indexed() { + return Err(exceptions::PyValueError::new_err( + format!("Field `{field_name}` is not set as indexed in the schema.") + )); + } + + Ok(field) + }).collect::>()? + } else { + schema + .fields() + .filter(|(_, field_entry)| field_entry.is_indexed()) + .map(|(field, _)| field) + .collect() + }; + + let mut parser = + tv::query::QueryParser::for_index(&self.index, default_fields); + + for (field_name, boost) in field_boosts { + let field = schema.get_field(&field_name).map_err(|_err| { + exceptions::PyValueError::new_err(format!( + "Field `{field_name}` is not defined in the schema." + )) + })?; + parser.set_field_boost(field, boost); + } + + for (field_name, (prefix, distance, transpose_cost_one)) in fuzzy_fields + { + let field = schema.get_field(&field_name).map_err(|_err| { + exceptions::PyValueError::new_err(format!( + "Field `{field_name}` is not defined in the schema." + )) + })?; + parser.set_field_fuzzy(field, prefix, distance, transpose_cost_one); + } + + Ok(parser) + } + fn register_custom_text_analyzers(index: &tv::Index) { let analyzers = [ ("ar_stem", Language::Arabic), diff --git a/src/parser_error.rs b/src/parser_error.rs index faff172..d91f1c8 100644 --- a/src/parser_error.rs +++ b/src/parser_error.rs @@ -304,10 +304,7 @@ impl ExpectedBase64Error { /// If `true`, an invalid byte was found in the query. Padding characters (`=`) interspersed in /// the encoded form will be treated as invalid bytes. fn caused_by_invalid_byte(&self) -> bool { - match self.decode_error { - base64::DecodeError::InvalidByte { .. } => true, - _ => false, - } + matches!(self.decode_error, base64::DecodeError::InvalidByte { .. }) } /// If the error was caused by an invalid byte, returns the offset and offending byte. @@ -322,19 +319,16 @@ impl ExpectedBase64Error { /// If `true`, the length of the base64 string was invalid. fn caused_by_invalid_length(&self) -> bool { - match self.decode_error { - base64::DecodeError::InvalidLength => true, - _ => false, - } + matches!(self.decode_error, base64::DecodeError::InvalidLength) } /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. /// If `true`, this is indicative of corrupted or truncated Base64. fn caused_by_invalid_last_symbol(&self) -> bool { - match self.decode_error { - base64::DecodeError::InvalidLastSymbol { .. } => true, - _ => false, - } + matches!( + self.decode_error, + base64::DecodeError::InvalidLastSymbol { .. } + ) } /// If the error was caused by an invalid last symbol, returns the offset and offending byte. @@ -350,10 +344,7 @@ impl ExpectedBase64Error { /// The nature of the padding was not as configured: absent or incorrect when it must be /// canonical, or present when it must be absent, etc. fn caused_by_invalid_padding(&self) -> bool { - match self.decode_error { - base64::DecodeError::InvalidPadding => true, - _ => false, - } + matches!(self.decode_error, base64::DecodeError::InvalidPadding) } fn __repr__(&self) -> String { diff --git a/src/snippet.rs b/src/snippet.rs index e5decc1..bb19d82 100644 --- a/src/snippet.rs +++ b/src/snippet.rs @@ -62,10 +62,10 @@ impl SnippetGenerator { tv::SnippetGenerator::create(&searcher.inner, query.get(), field) .map_err(to_pyerr)?; - return Ok(SnippetGenerator { + Ok(SnippetGenerator { field_name: field_name.to_string(), inner: generator, - }); + }) } pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet { diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index cabe977..5d32f87 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -94,6 +94,20 @@ class TestClass(object): == """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" ) + def test_parse_query_field_boosts(self, ram_index): + query = ram_index.parse_query("winter", field_boosts={"title": 2.3}) + assert ( + repr(query) + == """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" + ) + + def test_parse_query_field_boosts(self, ram_index): + query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)}) + assert ( + repr(query) + == """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" + ) + def test_query_errors(self, ram_index): index = ram_index # no "bod" field