Add field_boosts and fuzzy_fields optional parameters to Index::parse_query (#202)

master
Adam Reichold 2024-02-05 12:01:26 +01:00 committed by GitHub
parent e7224f1016
commit e95a4569d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 127 additions and 89 deletions

View File

@ -125,7 +125,7 @@ pub(crate) fn extract_value_for_type(
Value::JsonObject( Value::JsonObject(
any.downcast::<PyDict>() any.downcast::<PyDict>()
.map(|dict| pythonize::depythonize(&dict)) .map(|dict| pythonize::depythonize(dict))
.map_err(to_pyerr_for_type("Json", field_name, any))? .map_err(to_pyerr_for_type("Json", field_name, any))?
.map_err(to_pyerr_for_type("Json", field_name, any))?, .map_err(to_pyerr_for_type("Json", field_name, any))?,
) )

View File

@ -1,5 +1,7 @@
#![allow(clippy::new_ret_no_self)] #![allow(clippy::new_ret_no_self)]
use std::collections::HashMap;
use pyo3::{exceptions, prelude::*, types::PyAny}; use pyo3::{exceptions, prelude::*, types::PyAny};
use crate::{ use crate::{
@ -358,44 +360,33 @@ impl Index {
/// ///
/// Args: /// Args:
/// query: the query, following the tantivy query language. /// query: the query, following the tantivy query language.
///
/// default_fields_names (List[Field]): A list of fields used to search if no /// default_fields_names (List[Field]): A list of fields used to search if no
/// field is specified in the query. /// field is specified in the query.
/// ///
#[pyo3(signature = (query, default_field_names = None))] /// field_boosts: A dictionary keyed on field names which provides default boosts
/// for the query constructed by this method.
///
/// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
/// triples making queries constructed by this method fuzzy against the given fields
/// and using the given parameters.
/// `prefix` determines if terms which are prefixes of the given term match the query.
/// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
/// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
#[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
pub fn parse_query( pub fn parse_query(
&self, &self,
query: &str, query: &str,
default_field_names: Option<Vec<String>>, default_field_names: Option<Vec<String>>,
field_boosts: HashMap<String, tv::Score>,
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
) -> PyResult<Query> { ) -> PyResult<Query> {
let mut default_fields = vec![]; let parser = self.prepare_query_parser(
let schema = self.index.schema(); default_field_names,
if let Some(default_field_names_vec) = default_field_names { field_boosts,
for default_field_name in &default_field_names_vec { fuzzy_fields,
if let Ok(field) = schema.get_field(default_field_name) { )?;
let field_entry = schema.get_field_entry(field);
if !field_entry.is_indexed() {
return Err(exceptions::PyValueError::new_err(
format!(
"Field `{default_field_name}` is not set as indexed in the schema."
),
));
}
default_fields.push(field);
} else {
return Err(exceptions::PyValueError::new_err(format!(
"Field `{default_field_name}` is not defined in the schema."
)));
}
}
} else {
for (field, field_entry) in self.index.schema().fields() {
if field_entry.is_indexed() {
default_fields.push(field);
}
}
}
let parser =
tv::query::QueryParser::for_index(&self.index, default_fields);
let query = parser.parse_query(query).map_err(to_pyerr)?; let query = parser.parse_query(query).map_err(to_pyerr)?;
Ok(Query { inner: query }) Ok(Query { inner: query })
@ -410,64 +401,106 @@ impl Index {
/// ///
/// Args: /// Args:
/// query: the query, following the tantivy query language. /// query: the query, following the tantivy query language.
///
/// default_fields_names (List[Field]): A list of fields used to search if no /// default_fields_names (List[Field]): A list of fields used to search if no
/// field is specified in the query. /// field is specified in the query.
/// ///
/// field_boosts: A dictionary keyed on field names which provides default boosts
/// for the query constructed by this method.
///
/// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
/// triples making queries constructed by this method fuzzy against the given fields
/// and using the given parameters.
/// `prefix` determines if terms which are prefixes of the given term match the query.
/// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
/// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
///
/// Returns a tuple containing the parsed query and a list of errors. /// Returns a tuple containing the parsed query and a list of errors.
/// ///
/// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed. /// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed.
#[pyo3(signature = (query, default_field_names = None))] #[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
pub fn parse_query_lenient( pub fn parse_query_lenient(
&self, &self,
query: &str, query: &str,
default_field_names: Option<Vec<String>>, default_field_names: Option<Vec<String>>,
field_boosts: HashMap<String, tv::Score>,
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
py: Python,
) -> PyResult<(Query, Vec<PyObject>)> { ) -> PyResult<(Query, Vec<PyObject>)> {
let schema = self.index.schema(); let parser = self.prepare_query_parser(
default_field_names,
field_boosts,
fuzzy_fields,
)?;
let default_fields = if let Some(default_field_names_vec) =
default_field_names
{
default_field_names_vec
.iter()
.map(|field_name| {
schema
.get_field(field_name)
.map_err(|_err| {
exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is not defined in the schema."
))
})
.and_then(|field| {
schema.get_field_entry(field).is_indexed().then_some(field).ok_or(
exceptions::PyValueError::new_err(
format!(
"Field `{field_name}` is not set as indexed in the schema."
),
))
})
}).collect::<Result<Vec<_>, _>>()?
} else {
self.index
.schema()
.fields()
.filter_map(|(f, fe)| fe.is_indexed().then_some(f))
.collect::<Vec<_>>()
};
let parser =
tv::query::QueryParser::for_index(&self.index, default_fields);
let (query, errors) = parser.parse_query_lenient(query); let (query, errors) = parser.parse_query_lenient(query);
let errors = errors.into_iter().map(|err| err.into_py(py)).collect();
Python::with_gil(|py| { Ok((Query { inner: query }, errors))
let errors =
errors.into_iter().map(|err| err.into_py(py)).collect();
Ok((Query { inner: query }, errors))
})
} }
} }
impl Index { impl Index {
/// Builds a `tv::query::QueryParser` for this index from the optional
/// query-parsing settings shared by the query-parsing entry points.
///
/// * `default_field_names` - when `Some`, each name is resolved against the
///   schema and must refer to an indexed field; when `None`, every field
///   marked as indexed in the schema becomes a default field.
/// * `field_boosts` - each `(field name, boost)` entry is applied with
///   `set_field_boost`.
/// * `fuzzy_fields` - each `(field name, (prefix, distance, transpose_cost_one))`
///   entry is applied with `set_field_fuzzy`.
///
/// Returns a `PyValueError` if any named field is not defined in the schema,
/// or if a default field is not set as indexed.
fn prepare_query_parser(
    &self,
    default_field_names: Option<Vec<String>>,
    field_boosts: HashMap<String, tv::Score>,
    fuzzy_fields: HashMap<String, (bool, u8, bool)>,
) -> PyResult<tv::query::QueryParser> {
    let schema = self.index.schema();

    // Single place that turns a field name into a schema field, mapping a
    // failed lookup to the ValueError reported to Python callers.
    let lookup = |field_name: &str| {
        schema.get_field(field_name).map_err(|_err| {
            exceptions::PyValueError::new_err(format!(
                "Field `{field_name}` is not defined in the schema."
            ))
        })
    };

    let default_fields: Vec<_> = match default_field_names {
        Some(requested) => {
            let mut resolved = Vec::with_capacity(requested.len());
            for field_name in &requested {
                let field = lookup(field_name.as_str())?;
                // Explicitly requested default fields must be searchable.
                if !schema.get_field_entry(field).is_indexed() {
                    return Err(exceptions::PyValueError::new_err(
                        format!("Field `{field_name}` is not set as indexed in the schema.")
                    ));
                }
                resolved.push(field);
            }
            resolved
        }
        // No explicit list: fall back to every indexed field of the schema.
        None => schema
            .fields()
            .filter_map(|(field, field_entry)| {
                field_entry.is_indexed().then_some(field)
            })
            .collect(),
    };

    let mut parser =
        tv::query::QueryParser::for_index(&self.index, default_fields);

    for (field_name, boost) in field_boosts {
        let field = lookup(field_name.as_str())?;
        parser.set_field_boost(field, boost);
    }

    for (field_name, (prefix, distance, transpose_cost_one)) in fuzzy_fields {
        let field = lookup(field_name.as_str())?;
        parser.set_field_fuzzy(field, prefix, distance, transpose_cost_one);
    }

    Ok(parser)
}
fn register_custom_text_analyzers(index: &tv::Index) { fn register_custom_text_analyzers(index: &tv::Index) {
let analyzers = [ let analyzers = [
("ar_stem", Language::Arabic), ("ar_stem", Language::Arabic),

View File

@ -304,10 +304,7 @@ impl ExpectedBase64Error {
/// If `true`, an invalid byte was found in the query. Padding characters (`=`) interspersed in /// If `true`, an invalid byte was found in the query. Padding characters (`=`) interspersed in
/// the encoded form will be treated as invalid bytes. /// the encoded form will be treated as invalid bytes.
fn caused_by_invalid_byte(&self) -> bool { fn caused_by_invalid_byte(&self) -> bool {
match self.decode_error { matches!(self.decode_error, base64::DecodeError::InvalidByte { .. })
base64::DecodeError::InvalidByte { .. } => true,
_ => false,
}
} }
/// If the error was caused by an invalid byte, returns the offset and offending byte. /// If the error was caused by an invalid byte, returns the offset and offending byte.
@ -322,19 +319,16 @@ impl ExpectedBase64Error {
/// If `true`, the length of the base64 string was invalid. /// If `true`, the length of the base64 string was invalid.
fn caused_by_invalid_length(&self) -> bool { fn caused_by_invalid_length(&self) -> bool {
match self.decode_error { matches!(self.decode_error, base64::DecodeError::InvalidLength)
base64::DecodeError::InvalidLength => true,
_ => false,
}
} }
/// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
/// If `true`, this is indicative of corrupted or truncated Base64. /// If `true`, this is indicative of corrupted or truncated Base64.
fn caused_by_invalid_last_symbol(&self) -> bool { fn caused_by_invalid_last_symbol(&self) -> bool {
match self.decode_error { matches!(
base64::DecodeError::InvalidLastSymbol { .. } => true, self.decode_error,
_ => false, base64::DecodeError::InvalidLastSymbol { .. }
} )
} }
/// If the error was caused by an invalid last symbol, returns the offset and offending byte. /// If the error was caused by an invalid last symbol, returns the offset and offending byte.
@ -350,10 +344,7 @@ impl ExpectedBase64Error {
/// The nature of the padding was not as configured: absent or incorrect when it must be /// The nature of the padding was not as configured: absent or incorrect when it must be
/// canonical, or present when it must be absent, etc. /// canonical, or present when it must be absent, etc.
fn caused_by_invalid_padding(&self) -> bool { fn caused_by_invalid_padding(&self) -> bool {
match self.decode_error { matches!(self.decode_error, base64::DecodeError::InvalidPadding)
base64::DecodeError::InvalidPadding => true,
_ => false,
}
} }
fn __repr__(&self) -> String { fn __repr__(&self) -> String {

View File

@ -62,10 +62,10 @@ impl SnippetGenerator {
tv::SnippetGenerator::create(&searcher.inner, query.get(), field) tv::SnippetGenerator::create(&searcher.inner, query.get(), field)
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
return Ok(SnippetGenerator { Ok(SnippetGenerator {
field_name: field_name.to_string(), field_name: field_name.to_string(),
inner: generator, inner: generator,
}); })
} }
pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet { pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet {

View File

@ -94,6 +94,20 @@ class TestClass(object):
== """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })""" == """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
) )
def test_parse_query_field_boosts(self, ram_index):
    # A boost given for "title" should wrap only that field's term query
    # in a Boost node; the other default field stays a plain TermQuery.
    expected_repr = """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
    boosts = {"title": 2.3}
    parsed = ram_index.parse_query("winter", field_boosts=boosts)
    assert repr(parsed) == expected_repr
def test_parse_query_fuzzy_fields(self, ram_index):
    # Renamed from a second `test_parse_query_field_boosts`: two methods with
    # the same name in one class means the later definition replaces the
    # earlier one, so the field-boost test would never be collected or run.
    #
    # The fuzzy settings are a (prefix, distance, transpose_cost_one) triple;
    # only the "title" term is expected to become a FuzzyTermQuery.
    query = ram_index.parse_query(
        "winter", fuzzy_fields={"title": (True, 1, False)}
    )
    assert (
        repr(query)
        == """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
    )
def test_query_errors(self, ram_index): def test_query_errors(self, ram_index):
index = ram_index index = ram_index
# no "bod" field # no "bod" field