tantivy-py/src/query.rs

246 lines
7.2 KiB
Rust
Raw Normal View History

use crate::{get_field, make_term, to_pyerr, Schema};
2024-04-22 23:36:48 +00:00
use pyo3::{
exceptions,
prelude::*,
types::{PyAny, PyFloat, PyString, PyTuple},
2024-04-22 23:36:48 +00:00
};
use tantivy as tv;
/// Custom Tuple struct to represent a pair of Occur and Query
/// for the BooleanQuery
struct OccurQueryPair(Occur, Query);
2024-04-22 23:36:48 +00:00
impl<'source> FromPyObject<'source> for OccurQueryPair {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let tuple = ob.downcast::<PyTuple>()?;
let occur = tuple.get_item(0)?.extract()?;
let query = tuple.get_item(1)?.extract()?;
Ok(OccurQueryPair(occur, query))
}
}
/// Tantivy's Occur
///
/// Python-facing counterpart of `tv::query::Occur`; each variant converts
/// to the variant of the same name through the `From<Occur>` impl.
#[pyclass(frozen, module = "tantivy.tantivy")]
#[derive(Clone)]
pub enum Occur {
    /// Converts to `tv::query::Occur::Must`.
    Must,
    /// Converts to `tv::query::Occur::Should`.
    Should,
    /// Converts to `tv::query::Occur::MustNot`.
    MustNot,
}
impl From<Occur> for tv::query::Occur {
    /// Converts the Python-facing `Occur` into tantivy's own `Occur`,
    /// variant for variant.
    fn from(occur: Occur) -> Self {
        match occur {
            Occur::Must => Self::Must,
            Occur::Should => Self::Should,
            Occur::MustNot => Self::MustNot,
        }
    }
}
/// Tantivy's Query
2024-01-21 20:16:34 +00:00
/// Wrapper exposing a boxed tantivy query to Python as a frozen class in
/// the `tantivy.tantivy` module.
#[pyclass(frozen, module = "tantivy.tantivy")]
pub(crate) struct Query {
    /// The wrapped tantivy query, stored as a boxed trait object.
    pub(crate) inner: Box<dyn tv::query::Query>,
}
impl Clone for Query {
fn clone(&self) -> Self {
Query {
inner: self.inner.box_clone(),
}
}
}
impl Query {
    /// Borrows the wrapped tantivy query as a trait object.
    pub(crate) fn get(&self) -> &dyn tv::query::Query {
        // Explicitly deref the `Box` and re-borrow its contents.
        &*self.inner
    }
}
2022-04-15 03:50:37 +00:00
#[pymethods]
impl Query {
2019-08-02 11:23:10 +00:00
/// Python `repr()`: the wrapped query's `Debug` output enclosed in
/// `Query(...)`.
fn __repr__(&self) -> PyResult<String> {
    let wrapped = self.get();
    let rendered = format!("Query({:?})", wrapped);
    Ok(rendered)
}
2023-12-20 09:40:50 +00:00
/// Construct a Tantivy's TermQuery
///
/// # Arguments
///
/// * `schema` - Schema used to build the term.
/// * `field_name` - Name of the field the term belongs to.
/// * `field_value` - Python value turned into the term by `make_term`.
/// * `index_option` - One of `"position"` (default), `"freq"` or `"basic"`,
///   selecting the `IndexRecordOption` given to the query.
///
/// # Errors
///
/// Propagates any error from `make_term`, and raises `ValueError` when
/// `index_option` is not one of the accepted strings.
#[staticmethod]
#[pyo3(signature = (schema, field_name, field_value, index_option = "position"))]
pub(crate) fn term_query(
    schema: &Schema,
    field_name: &str,
    field_value: &PyAny,
    index_option: &str,
) -> PyResult<Query> {
    use tv::schema::IndexRecordOption;

    // The term is built first so that a bad field/value is reported ahead
    // of a bad `index_option`.
    let term = make_term(&schema.inner, field_name, field_value)?;

    let record_option = match index_option {
        "basic" => Ok(IndexRecordOption::Basic),
        "freq" => Ok(IndexRecordOption::WithFreqs),
        "position" => Ok(IndexRecordOption::WithFreqsAndPositions),
        _ => Err(exceptions::PyValueError::new_err(
            "Invalid index option, valid choices are: 'basic', 'freq' and 'position'",
        )),
    }?;

    Ok(Query {
        inner: Box::new(tv::query::TermQuery::new(term, record_option)),
    })
}
2024-03-31 11:56:22 +00:00
2024-04-26 11:21:46 +00:00
/// Construct a Tantivy's TermSetQuery
///
/// # Arguments
///
/// * `schema` - Schema used to build the terms.
/// * `field_name` - Name of the field every term belongs to.
/// * `field_values` - Python values, each turned into a term by `make_term`.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `make_term`.
#[staticmethod]
#[pyo3(signature = (schema, field_name, field_values))]
pub(crate) fn term_set_query(
    schema: &Schema,
    field_name: &str,
    field_values: Vec<&PyAny>,
) -> PyResult<Query> {
    let mut terms = Vec::with_capacity(field_values.len());
    for value in field_values {
        terms.push(make_term(&schema.inner, field_name, value)?);
    }

    let term_set = tv::query::TermSetQuery::new(terms);
    Ok(Query {
        inner: Box::new(term_set),
    })
}
2024-03-31 11:56:22 +00:00
/// Construct a Tantivy's AllQuery
///
/// Takes no arguments and always succeeds.
#[staticmethod]
pub(crate) fn all_query() -> PyResult<Query> {
    Ok(Query {
        inner: Box::new(tv::query::AllQuery),
    })
}
2024-04-13 09:14:56 +00:00
/// Construct a Tantivy's FuzzyTermQuery
///
/// # Arguments
///
/// * `schema` - Schema of the target index.
/// * `field_name` - Field name to be searched.
/// * `text` - String representation of the query term.
/// * `distance` - (Optional) Edit distance you are going to alow. When not specified, the default is 1.
/// * `transposition_cost_one` - (Optional) If true, a transposition (swapping) cost will be 1; otherwise it will be 2. When not specified, the default is true.
/// * `prefix` - (Optional) If true, prefix levenshtein distance is applied. When not specified, the default is false.
2024-04-13 09:14:56 +00:00
#[staticmethod]
#[pyo3(signature = (schema, field_name, text, distance = 1, transposition_cost_one = true, prefix = false))]
pub(crate) fn fuzzy_term_query(
schema: &Schema,
field_name: &str,
text: &PyString,
distance: u8,
transposition_cost_one: bool,
prefix: bool,
) -> PyResult<Query> {
let term = make_term(&schema.inner, field_name, &text)?;
let inner = if prefix {
tv::query::FuzzyTermQuery::new_prefix(
term,
distance,
transposition_cost_one,
)
} else {
tv::query::FuzzyTermQuery::new(
term,
distance,
transposition_cost_one,
)
};
Ok(Query {
inner: Box::new(inner),
})
}
2024-04-28 10:54:21 +00:00
/// Construct a Tantivy's BooleanQuery
#[staticmethod]
#[pyo3(signature = (subqueries))]
pub(crate) fn boolean_query(
2024-04-22 23:36:48 +00:00
subqueries: Vec<(Occur, Query)>,
) -> PyResult<Query> {
let dyn_subqueries = subqueries
.into_iter()
.map(|(occur, query)| (occur.into(), query.inner.box_clone()))
.collect::<Vec<_>>();
2024-04-22 23:36:48 +00:00
let inner = tv::query::BooleanQuery::from(dyn_subqueries);
Ok(Query {
inner: Box::new(inner),
})
}
/// Construct a Tantivy's DisjunctionMaxQuery
#[staticmethod]
pub(crate) fn disjunction_max_query(
subqueries: Vec<Query>,
tie_breaker: Option<&PyFloat>,
) -> PyResult<Query> {
let inner_queries: Vec<Box<dyn tv::query::Query>> = subqueries
.iter()
.map(|query| query.inner.box_clone())
.collect();
let dismax_query = if let Some(tie_breaker) = tie_breaker {
tv::query::DisjunctionMaxQuery::with_tie_breaker(
inner_queries,
tie_breaker.extract::<f32>()?,
)
} else {
tv::query::DisjunctionMaxQuery::new(inner_queries)
};
Ok(Query {
inner: Box::new(dismax_query),
})
}
2024-04-24 21:57:16 +00:00
2024-04-28 10:54:21 +00:00
/// Construct a Tantivy's BoostQuery
2024-04-24 21:57:16 +00:00
#[staticmethod]
#[pyo3(signature = (query, boost))]
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
let inner = tv::query::BoostQuery::new(query.inner, boost);
Ok(Query {
inner: Box::new(inner),
})
}
2024-04-28 10:54:21 +00:00
/// Construct a Tantivy's RegexQuery
#[staticmethod]
#[pyo3(signature = (schema, field_name, regex_pattern))]
pub(crate) fn regex_query(
schema: &Schema,
field_name: &str,
regex_pattern: &str,
) -> PyResult<Query> {
let field = get_field(&schema.inner, field_name)?;
let inner_result =
tv::query::RegexQuery::from_pattern(regex_pattern, field);
match inner_result {
Ok(inner) => Ok(Query {
inner: Box::new(inner),
}),
Err(e) => Err(to_pyerr(e)),
}
}
2024-04-28 10:54:21 +00:00
/// Construct a Tantivy's ConstScoreQuery
///
/// # Arguments
///
/// * `query` - Query to wrap; its boxed inner query is moved into the
///   `ConstScoreQuery`.
/// * `score` - Score value handed to `ConstScoreQuery::new`.
#[staticmethod]
#[pyo3(signature = (query, score))]
pub(crate) fn const_score_query(
    query: Query,
    score: f32,
) -> PyResult<Query> {
    let Query { inner: wrapped } = query;
    Ok(Query {
        inner: Box::new(tv::query::ConstScoreQuery::new(wrapped, score)),
    })
}
}