Add field_boosts and fuzzy_fields optional parameters to Index::parse_query (#202)
parent e7224f1016
commit e95a4569d4
@@ -125,7 +125,7 @@ pub(crate) fn extract_value_for_type(
         Value::JsonObject(
             any.downcast::<PyDict>()
-                .map(|dict| pythonize::depythonize(&dict))
+                .map(|dict| pythonize::depythonize(dict))
                 .map_err(to_pyerr_for_type("Json", field_name, any))?
                 .map_err(to_pyerr_for_type("Json", field_name, any))?,
         )

173 src/index.rs

@@ -1,5 +1,7 @@
 #![allow(clippy::new_ret_no_self)]
 
+use std::collections::HashMap;
+
 use pyo3::{exceptions, prelude::*, types::PyAny};
 
 use crate::{
@@ -358,44 +360,33 @@ impl Index {
     ///
     /// Args:
     ///     query: the query, following the tantivy query language.
     ///
     ///     default_fields_names (List[Field]): A list of fields used to search if no
     ///         field is specified in the query.
     ///
-    #[pyo3(signature = (query, default_field_names = None))]
+    ///     field_boosts: A dictionary keyed on field names which provides default boosts
+    ///         for the query constructed by this method.
+    ///
+    ///     fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
+    ///         triples making queries constructed by this method fuzzy against the given fields
+    ///         and using the given parameters.
+    ///         `prefix` determines if terms which are prefixes of the given term match the query.
+    ///         `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
+    ///         `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
+    #[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
     pub fn parse_query(
         &self,
         query: &str,
         default_field_names: Option<Vec<String>>,
+        field_boosts: HashMap<String, tv::Score>,
+        fuzzy_fields: HashMap<String, (bool, u8, bool)>,
     ) -> PyResult<Query> {
-        let mut default_fields = vec![];
-        let schema = self.index.schema();
-        if let Some(default_field_names_vec) = default_field_names {
-            for default_field_name in &default_field_names_vec {
-                if let Ok(field) = schema.get_field(default_field_name) {
-                    let field_entry = schema.get_field_entry(field);
-                    if !field_entry.is_indexed() {
-                        return Err(exceptions::PyValueError::new_err(
-                            format!(
-                                "Field `{default_field_name}` is not set as indexed in the schema."
-                            ),
-                        ));
-                    }
-                    default_fields.push(field);
-                } else {
-                    return Err(exceptions::PyValueError::new_err(format!(
-                        "Field `{default_field_name}` is not defined in the schema."
-                    )));
-                }
-            }
-        } else {
-            for (field, field_entry) in self.index.schema().fields() {
-                if field_entry.is_indexed() {
-                    default_fields.push(field);
-                }
-            }
-        }
-        let parser =
-            tv::query::QueryParser::for_index(&self.index, default_fields);
+        let parser = self.prepare_query_parser(
+            default_field_names,
+            field_boosts,
+            fuzzy_fields,
+        )?;
 
         let query = parser.parse_query(query).map_err(to_pyerr)?;
 
         Ok(Query { inner: query })
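The two new keyword arguments are easiest to read from the Python side. The sketch below mirrors the tests added at the bottom of this commit; the schema-building calls are the usual tantivy-py API, and the "title"/"body" field names are only the ones the test fixture happens to use, not anything required by this change.

    import tantivy

    # A small in-memory index with two indexed text fields.
    schema_builder = tantivy.SchemaBuilder()
    schema_builder.add_text_field("title", stored=True)
    schema_builder.add_text_field("body", stored=True)
    index = tantivy.Index(schema_builder.build())

    # field_boosts: boost matches on "title" by 2.3 relative to other fields.
    boosted = index.parse_query("winter", field_boosts={"title": 2.3})

    # fuzzy_fields: (prefix, distance, transpose_cost_one) per field.
    # Here "title" terms match with prefix matching enabled, a maximum
    # Levenshtein distance of 1, and transpositions counted at full cost.
    fuzzy = index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)})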
@@ -410,64 +401,106 @@ impl Index {
     ///
     /// Args:
     ///     query: the query, following the tantivy query language.
     ///
     ///     default_fields_names (List[Field]): A list of fields used to search if no
     ///         field is specified in the query.
     ///
+    ///     field_boosts: A dictionary keyed on field names which provides default boosts
+    ///         for the query constructed by this method.
+    ///
+    ///     fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
+    ///         triples making queries constructed by this method fuzzy against the given fields
+    ///         and using the given parameters.
+    ///         `prefix` determines if terms which are prefixes of the given term match the query.
+    ///         `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
+    ///         `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
+    ///
     /// Returns a tuple containing the parsed query and a list of errors.
     ///
     /// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed.
-    #[pyo3(signature = (query, default_field_names = None))]
+    #[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
     pub fn parse_query_lenient(
         &self,
         query: &str,
         default_field_names: Option<Vec<String>>,
+        field_boosts: HashMap<String, tv::Score>,
+        fuzzy_fields: HashMap<String, (bool, u8, bool)>,
         py: Python,
     ) -> PyResult<(Query, Vec<PyObject>)> {
-        let schema = self.index.schema();
+        let parser = self.prepare_query_parser(
+            default_field_names,
+            field_boosts,
+            fuzzy_fields,
+        )?;
 
-        let default_fields = if let Some(default_field_names_vec) =
-            default_field_names
-        {
-            default_field_names_vec
-                .iter()
-                .map(|field_name| {
-                    schema
-                        .get_field(field_name)
-                        .map_err(|_err| {
-                            exceptions::PyValueError::new_err(format!(
-                                "Field `{field_name}` is not defined in the schema."
-                            ))
-                        })
-                        .and_then(|field| {
-                            schema.get_field_entry(field).is_indexed().then_some(field).ok_or(
-                                exceptions::PyValueError::new_err(
-                                    format!(
-                                        "Field `{field_name}` is not set as indexed in the schema."
-                                    ),
-                                ))
-                        })
-                }).collect::<Result<Vec<_>, _>>()?
-        } else {
-            self.index
-                .schema()
-                .fields()
-                .filter_map(|(f, fe)| fe.is_indexed().then_some(f))
-                .collect::<Vec<_>>()
-        };
-
-        let parser =
-            tv::query::QueryParser::for_index(&self.index, default_fields);
         let (query, errors) = parser.parse_query_lenient(query);
+        let errors = errors.into_iter().map(|err| err.into_py(py)).collect();
 
-        Python::with_gil(|py| {
-            let errors =
-                errors.into_iter().map(|err| err.into_py(py)).collect();
-
-            Ok((Query { inner: query }, errors))
-        })
+        Ok((Query { inner: query }, errors))
     }
 }
 
+impl Index {
+    fn prepare_query_parser(
+        &self,
+        default_field_names: Option<Vec<String>>,
+        field_boosts: HashMap<String, tv::Score>,
+        fuzzy_fields: HashMap<String, (bool, u8, bool)>,
+    ) -> PyResult<tv::query::QueryParser> {
+        let schema = self.index.schema();
+
+        let default_fields = if let Some(default_field_names) =
+            default_field_names
+        {
+            default_field_names.iter().map(|field_name| {
+                let field = schema.get_field(field_name).map_err(|_err| {
+                    exceptions::PyValueError::new_err(format!(
+                        "Field `{field_name}` is not defined in the schema."
+                    ))
+                })?;
+
+                let field_entry = schema.get_field_entry(field);
+                if !field_entry.is_indexed() {
+                    return Err(exceptions::PyValueError::new_err(
+                        format!("Field `{field_name}` is not set as indexed in the schema.")
+                    ));
+                }
+
+                Ok(field)
+            }).collect::<PyResult<_>>()?
+        } else {
+            schema
+                .fields()
+                .filter(|(_, field_entry)| field_entry.is_indexed())
+                .map(|(field, _)| field)
+                .collect()
+        };
+
+        let mut parser =
+            tv::query::QueryParser::for_index(&self.index, default_fields);
+
+        for (field_name, boost) in field_boosts {
+            let field = schema.get_field(&field_name).map_err(|_err| {
+                exceptions::PyValueError::new_err(format!(
+                    "Field `{field_name}` is not defined in the schema."
+                ))
+            })?;
+            parser.set_field_boost(field, boost);
+        }
+
+        for (field_name, (prefix, distance, transpose_cost_one)) in fuzzy_fields
+        {
+            let field = schema.get_field(&field_name).map_err(|_err| {
+                exceptions::PyValueError::new_err(format!(
+                    "Field `{field_name}` is not defined in the schema."
+                ))
+            })?;
+            parser.set_field_fuzzy(field, prefix, distance, transpose_cost_one);
+        }
+
+        Ok(parser)
+    }
+}
+
 fn register_custom_text_analyzers(index: &tv::Index) {
     let analyzers = [
         ("ar_stem", Language::Arabic),

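parse_query_lenient gains the same keyword arguments and, as its docstring says, returns a (query, errors) tuple rather than raising on malformed query syntax. A minimal sketch of consuming that tuple, assuming the same two-field schema as in the earlier example and an intentionally unbalanced query string:

    import tantivy

    schema_builder = tantivy.SchemaBuilder()
    schema_builder.add_text_field("title", stored=True)
    schema_builder.add_text_field("body", stored=True)
    index = tantivy.Index(schema_builder.build())

    # Best-effort parse: syntax problems are returned, not raised.
    query, errors = index.parse_query_lenient(
        "winter AND (title:snow",
        fuzzy_fields={"title": (True, 1, False)},
    )
    if errors:
        print(f"parsed with {len(errors)} error(s)")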
@@ -304,10 +304,7 @@ impl ExpectedBase64Error {
     /// If `true`, an invalid byte was found in the query. Padding characters (`=`) interspersed in
     /// the encoded form will be treated as invalid bytes.
     fn caused_by_invalid_byte(&self) -> bool {
-        match self.decode_error {
-            base64::DecodeError::InvalidByte { .. } => true,
-            _ => false,
-        }
+        matches!(self.decode_error, base64::DecodeError::InvalidByte { .. })
     }
 
     /// If the error was caused by an invalid byte, returns the offset and offending byte.
@@ -322,19 +319,16 @@ impl ExpectedBase64Error {
 
     /// If `true`, the length of the base64 string was invalid.
     fn caused_by_invalid_length(&self) -> bool {
-        match self.decode_error {
-            base64::DecodeError::InvalidLength => true,
-            _ => false,
-        }
+        matches!(self.decode_error, base64::DecodeError::InvalidLength)
     }
 
     /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
     /// If `true`, this is indicative of corrupted or truncated Base64.
    fn caused_by_invalid_last_symbol(&self) -> bool {
-        match self.decode_error {
-            base64::DecodeError::InvalidLastSymbol { .. } => true,
-            _ => false,
-        }
+        matches!(
+            self.decode_error,
+            base64::DecodeError::InvalidLastSymbol { .. }
+        )
     }
 
     /// If the error was caused by an invalid last symbol, returns the offset and offending byte.
@@ -350,10 +344,7 @@ impl ExpectedBase64Error {
     /// The nature of the padding was not as configured: absent or incorrect when it must be
     /// canonical, or present when it must be absent, etc.
     fn caused_by_invalid_padding(&self) -> bool {
-        match self.decode_error {
-            base64::DecodeError::InvalidPadding => true,
-            _ => false,
-        }
+        matches!(self.decode_error, base64::DecodeError::InvalidPadding)
     }
 
     fn __repr__(&self) -> String {
@@ -62,10 +62,10 @@ impl SnippetGenerator {
             tv::SnippetGenerator::create(&searcher.inner, query.get(), field)
                 .map_err(to_pyerr)?;
 
-        return Ok(SnippetGenerator {
+        Ok(SnippetGenerator {
             field_name: field_name.to_string(),
             inner: generator,
-        });
+        })
     }
 
     pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet {
@@ -94,6 +94,20 @@ class TestClass(object):
             == """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
         )
 
+    def test_parse_query_field_boosts(self, ram_index):
+        query = ram_index.parse_query("winter", field_boosts={"title": 2.3})
+        assert (
+            repr(query)
+            == """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
+        )
+
+    def test_parse_query_fuzzy_fields(self, ram_index):
+        query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)})
+        assert (
+            repr(query)
+            == """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
+        )
+
     def test_query_errors(self, ram_index):
         index = ram_index
         # no "bod" field
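One behaviour the tests above do not exercise: prepare_query_parser rejects field_boosts or fuzzy_fields entries for fields that are not defined in the schema, reusing the same ValueError message as the default_field_names check. A small sketch of that path, again on the assumed two-field schema ("headline" is simply a field that was never added):

    import tantivy

    schema_builder = tantivy.SchemaBuilder()
    schema_builder.add_text_field("title", stored=True)
    schema_builder.add_text_field("body", stored=True)
    index = tantivy.Index(schema_builder.build())

    try:
        # "headline" is not in the schema, so the parser setup fails.
        index.parse_query("winter", field_boosts={"headline": 2.0})
    except ValueError as exc:
        assert "is not defined in the schema" in str(exc)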