Add field_boosts and fuzzy_fields optional parameters to Index::parse_query (#202)
parent
e7224f1016
commit
e95a4569d4
|
@ -125,7 +125,7 @@ pub(crate) fn extract_value_for_type(
|
||||||
|
|
||||||
Value::JsonObject(
|
Value::JsonObject(
|
||||||
any.downcast::<PyDict>()
|
any.downcast::<PyDict>()
|
||||||
.map(|dict| pythonize::depythonize(&dict))
|
.map(|dict| pythonize::depythonize(dict))
|
||||||
.map_err(to_pyerr_for_type("Json", field_name, any))?
|
.map_err(to_pyerr_for_type("Json", field_name, any))?
|
||||||
.map_err(to_pyerr_for_type("Json", field_name, any))?,
|
.map_err(to_pyerr_for_type("Json", field_name, any))?,
|
||||||
)
|
)
|
||||||
|
|
171
src/index.rs
171
src/index.rs
|
@ -1,5 +1,7 @@
|
||||||
#![allow(clippy::new_ret_no_self)]
|
#![allow(clippy::new_ret_no_self)]
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use pyo3::{exceptions, prelude::*, types::PyAny};
|
use pyo3::{exceptions, prelude::*, types::PyAny};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -358,44 +360,33 @@ impl Index {
|
||||||
///
|
///
|
||||||
/// Args:
|
/// Args:
|
||||||
/// query: the query, following the tantivy query language.
|
/// query: the query, following the tantivy query language.
|
||||||
|
///
|
||||||
/// default_fields_names (List[Field]): A list of fields used to search if no
|
/// default_fields_names (List[Field]): A list of fields used to search if no
|
||||||
/// field is specified in the query.
|
/// field is specified in the query.
|
||||||
///
|
///
|
||||||
#[pyo3(signature = (query, default_field_names = None))]
|
/// field_boosts: A dictionary keyed on field names which provides default boosts
|
||||||
|
/// for the query constructed by this method.
|
||||||
|
///
|
||||||
|
/// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
|
||||||
|
/// triples making queries constructed by this method fuzzy against the given fields
|
||||||
|
/// and using the given parameters.
|
||||||
|
/// `prefix` determines if terms which are prefixes of the given term match the query.
|
||||||
|
/// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
|
||||||
|
/// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
|
||||||
|
#[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
|
||||||
pub fn parse_query(
|
pub fn parse_query(
|
||||||
&self,
|
&self,
|
||||||
query: &str,
|
query: &str,
|
||||||
default_field_names: Option<Vec<String>>,
|
default_field_names: Option<Vec<String>>,
|
||||||
|
field_boosts: HashMap<String, tv::Score>,
|
||||||
|
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
|
||||||
) -> PyResult<Query> {
|
) -> PyResult<Query> {
|
||||||
let mut default_fields = vec![];
|
let parser = self.prepare_query_parser(
|
||||||
let schema = self.index.schema();
|
default_field_names,
|
||||||
if let Some(default_field_names_vec) = default_field_names {
|
field_boosts,
|
||||||
for default_field_name in &default_field_names_vec {
|
fuzzy_fields,
|
||||||
if let Ok(field) = schema.get_field(default_field_name) {
|
)?;
|
||||||
let field_entry = schema.get_field_entry(field);
|
|
||||||
if !field_entry.is_indexed() {
|
|
||||||
return Err(exceptions::PyValueError::new_err(
|
|
||||||
format!(
|
|
||||||
"Field `{default_field_name}` is not set as indexed in the schema."
|
|
||||||
),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
default_fields.push(field);
|
|
||||||
} else {
|
|
||||||
return Err(exceptions::PyValueError::new_err(format!(
|
|
||||||
"Field `{default_field_name}` is not defined in the schema."
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (field, field_entry) in self.index.schema().fields() {
|
|
||||||
if field_entry.is_indexed() {
|
|
||||||
default_fields.push(field);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let parser =
|
|
||||||
tv::query::QueryParser::for_index(&self.index, default_fields);
|
|
||||||
let query = parser.parse_query(query).map_err(to_pyerr)?;
|
let query = parser.parse_query(query).map_err(to_pyerr)?;
|
||||||
|
|
||||||
Ok(Query { inner: query })
|
Ok(Query { inner: query })
|
||||||
|
@ -410,64 +401,106 @@ impl Index {
|
||||||
///
|
///
|
||||||
/// Args:
|
/// Args:
|
||||||
/// query: the query, following the tantivy query language.
|
/// query: the query, following the tantivy query language.
|
||||||
|
///
|
||||||
/// default_fields_names (List[Field]): A list of fields used to search if no
|
/// default_fields_names (List[Field]): A list of fields used to search if no
|
||||||
/// field is specified in the query.
|
/// field is specified in the query.
|
||||||
///
|
///
|
||||||
|
/// field_boosts: A dictionary keyed on field names which provides default boosts
|
||||||
|
/// for the query constructed by this method.
|
||||||
|
///
|
||||||
|
/// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
|
||||||
|
/// triples making queries constructed by this method fuzzy against the given fields
|
||||||
|
/// and using the given parameters.
|
||||||
|
/// `prefix` determines if terms which are prefixes of the given term match the query.
|
||||||
|
/// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
|
||||||
|
/// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
|
||||||
|
///
|
||||||
/// Returns a tuple containing the parsed query and a list of errors.
|
/// Returns a tuple containing the parsed query and a list of errors.
|
||||||
///
|
///
|
||||||
/// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed.
|
/// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed.
|
||||||
#[pyo3(signature = (query, default_field_names = None))]
|
#[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
|
||||||
pub fn parse_query_lenient(
|
pub fn parse_query_lenient(
|
||||||
&self,
|
&self,
|
||||||
query: &str,
|
query: &str,
|
||||||
default_field_names: Option<Vec<String>>,
|
default_field_names: Option<Vec<String>>,
|
||||||
|
field_boosts: HashMap<String, tv::Score>,
|
||||||
|
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
|
||||||
|
py: Python,
|
||||||
) -> PyResult<(Query, Vec<PyObject>)> {
|
) -> PyResult<(Query, Vec<PyObject>)> {
|
||||||
let schema = self.index.schema();
|
let parser = self.prepare_query_parser(
|
||||||
|
default_field_names,
|
||||||
|
field_boosts,
|
||||||
|
fuzzy_fields,
|
||||||
|
)?;
|
||||||
|
|
||||||
let default_fields = if let Some(default_field_names_vec) =
|
|
||||||
default_field_names
|
|
||||||
{
|
|
||||||
default_field_names_vec
|
|
||||||
.iter()
|
|
||||||
.map(|field_name| {
|
|
||||||
schema
|
|
||||||
.get_field(field_name)
|
|
||||||
.map_err(|_err| {
|
|
||||||
exceptions::PyValueError::new_err(format!(
|
|
||||||
"Field `{field_name}` is not defined in the schema."
|
|
||||||
))
|
|
||||||
})
|
|
||||||
.and_then(|field| {
|
|
||||||
schema.get_field_entry(field).is_indexed().then_some(field).ok_or(
|
|
||||||
exceptions::PyValueError::new_err(
|
|
||||||
format!(
|
|
||||||
"Field `{field_name}` is not set as indexed in the schema."
|
|
||||||
),
|
|
||||||
))
|
|
||||||
})
|
|
||||||
}).collect::<Result<Vec<_>, _>>()?
|
|
||||||
} else {
|
|
||||||
self.index
|
|
||||||
.schema()
|
|
||||||
.fields()
|
|
||||||
.filter_map(|(f, fe)| fe.is_indexed().then_some(f))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
};
|
|
||||||
|
|
||||||
let parser =
|
|
||||||
tv::query::QueryParser::for_index(&self.index, default_fields);
|
|
||||||
let (query, errors) = parser.parse_query_lenient(query);
|
let (query, errors) = parser.parse_query_lenient(query);
|
||||||
|
let errors = errors.into_iter().map(|err| err.into_py(py)).collect();
|
||||||
Python::with_gil(|py| {
|
|
||||||
let errors =
|
|
||||||
errors.into_iter().map(|err| err.into_py(py)).collect();
|
|
||||||
|
|
||||||
Ok((Query { inner: query }, errors))
|
Ok((Query { inner: query }, errors))
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index {
|
impl Index {
|
||||||
|
fn prepare_query_parser(
|
||||||
|
&self,
|
||||||
|
default_field_names: Option<Vec<String>>,
|
||||||
|
field_boosts: HashMap<String, tv::Score>,
|
||||||
|
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
|
||||||
|
) -> PyResult<tv::query::QueryParser> {
|
||||||
|
let schema = self.index.schema();
|
||||||
|
|
||||||
|
let default_fields = if let Some(default_field_names) =
|
||||||
|
default_field_names
|
||||||
|
{
|
||||||
|
default_field_names.iter().map(|field_name| {
|
||||||
|
let field = schema.get_field(field_name).map_err(|_err| {
|
||||||
|
exceptions::PyValueError::new_err(format!(
|
||||||
|
"Field `{field_name}` is not defined in the schema."
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let field_entry = schema.get_field_entry(field);
|
||||||
|
if !field_entry.is_indexed() {
|
||||||
|
return Err(exceptions::PyValueError::new_err(
|
||||||
|
format!("Field `{field_name}` is not set as indexed in the schema.")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(field)
|
||||||
|
}).collect::<PyResult<_>>()?
|
||||||
|
} else {
|
||||||
|
schema
|
||||||
|
.fields()
|
||||||
|
.filter(|(_, field_entry)| field_entry.is_indexed())
|
||||||
|
.map(|(field, _)| field)
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut parser =
|
||||||
|
tv::query::QueryParser::for_index(&self.index, default_fields);
|
||||||
|
|
||||||
|
for (field_name, boost) in field_boosts {
|
||||||
|
let field = schema.get_field(&field_name).map_err(|_err| {
|
||||||
|
exceptions::PyValueError::new_err(format!(
|
||||||
|
"Field `{field_name}` is not defined in the schema."
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
parser.set_field_boost(field, boost);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (field_name, (prefix, distance, transpose_cost_one)) in fuzzy_fields
|
||||||
|
{
|
||||||
|
let field = schema.get_field(&field_name).map_err(|_err| {
|
||||||
|
exceptions::PyValueError::new_err(format!(
|
||||||
|
"Field `{field_name}` is not defined in the schema."
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
parser.set_field_fuzzy(field, prefix, distance, transpose_cost_one);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(parser)
|
||||||
|
}
|
||||||
|
|
||||||
fn register_custom_text_analyzers(index: &tv::Index) {
|
fn register_custom_text_analyzers(index: &tv::Index) {
|
||||||
let analyzers = [
|
let analyzers = [
|
||||||
("ar_stem", Language::Arabic),
|
("ar_stem", Language::Arabic),
|
||||||
|
|
|
@ -304,10 +304,7 @@ impl ExpectedBase64Error {
|
||||||
/// If `true`, an invalid byte was found in the query. Padding characters (`=`) interspersed in
|
/// If `true`, an invalid byte was found in the query. Padding characters (`=`) interspersed in
|
||||||
/// the encoded form will be treated as invalid bytes.
|
/// the encoded form will be treated as invalid bytes.
|
||||||
fn caused_by_invalid_byte(&self) -> bool {
|
fn caused_by_invalid_byte(&self) -> bool {
|
||||||
match self.decode_error {
|
matches!(self.decode_error, base64::DecodeError::InvalidByte { .. })
|
||||||
base64::DecodeError::InvalidByte { .. } => true,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If the error was caused by an invalid byte, returns the offset and offending byte.
|
/// If the error was caused by an invalid byte, returns the offset and offending byte.
|
||||||
|
@ -322,19 +319,16 @@ impl ExpectedBase64Error {
|
||||||
|
|
||||||
/// If `true`, the length of the base64 string was invalid.
|
/// If `true`, the length of the base64 string was invalid.
|
||||||
fn caused_by_invalid_length(&self) -> bool {
|
fn caused_by_invalid_length(&self) -> bool {
|
||||||
match self.decode_error {
|
matches!(self.decode_error, base64::DecodeError::InvalidLength)
|
||||||
base64::DecodeError::InvalidLength => true,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
|
/// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
|
||||||
/// If `true`, this is indicative of corrupted or truncated Base64.
|
/// If `true`, this is indicative of corrupted or truncated Base64.
|
||||||
fn caused_by_invalid_last_symbol(&self) -> bool {
|
fn caused_by_invalid_last_symbol(&self) -> bool {
|
||||||
match self.decode_error {
|
matches!(
|
||||||
base64::DecodeError::InvalidLastSymbol { .. } => true,
|
self.decode_error,
|
||||||
_ => false,
|
base64::DecodeError::InvalidLastSymbol { .. }
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If the error was caused by an invalid last symbol, returns the offset and offending byte.
|
/// If the error was caused by an invalid last symbol, returns the offset and offending byte.
|
||||||
|
@ -350,10 +344,7 @@ impl ExpectedBase64Error {
|
||||||
/// The nature of the padding was not as configured: absent or incorrect when it must be
|
/// The nature of the padding was not as configured: absent or incorrect when it must be
|
||||||
/// canonical, or present when it must be absent, etc.
|
/// canonical, or present when it must be absent, etc.
|
||||||
fn caused_by_invalid_padding(&self) -> bool {
|
fn caused_by_invalid_padding(&self) -> bool {
|
||||||
match self.decode_error {
|
matches!(self.decode_error, base64::DecodeError::InvalidPadding)
|
||||||
base64::DecodeError::InvalidPadding => true,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __repr__(&self) -> String {
|
fn __repr__(&self) -> String {
|
||||||
|
|
|
@ -62,10 +62,10 @@ impl SnippetGenerator {
|
||||||
tv::SnippetGenerator::create(&searcher.inner, query.get(), field)
|
tv::SnippetGenerator::create(&searcher.inner, query.get(), field)
|
||||||
.map_err(to_pyerr)?;
|
.map_err(to_pyerr)?;
|
||||||
|
|
||||||
return Ok(SnippetGenerator {
|
Ok(SnippetGenerator {
|
||||||
field_name: field_name.to_string(),
|
field_name: field_name.to_string(),
|
||||||
inner: generator,
|
inner: generator,
|
||||||
});
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet {
|
pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet {
|
||||||
|
|
|
@ -94,6 +94,20 @@ class TestClass(object):
|
||||||
== """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
|
== """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_parse_query_field_boosts(self, ram_index):
|
||||||
|
query = ram_index.parse_query("winter", field_boosts={"title": 2.3})
|
||||||
|
assert (
|
||||||
|
repr(query)
|
||||||
|
== """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_parse_query_field_boosts(self, ram_index):
|
||||||
|
query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)})
|
||||||
|
assert (
|
||||||
|
repr(query)
|
||||||
|
== """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
|
||||||
|
)
|
||||||
|
|
||||||
def test_query_errors(self, ram_index):
|
def test_query_errors(self, ram_index):
|
||||||
index = ram_index
|
index = ram_index
|
||||||
# no "bod" field
|
# no "bod" field
|
||||||
|
|
Loading…
Reference in New Issue