Expose Tantivy's MoreLikeThisQuery (#257)

master
Tomoko Uchida 2024-05-04 05:15:21 +09:00 committed by GitHub
parent 03b1c89fa3
commit 9fafdf25cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 99 additions and 4 deletions

View File

@ -1,4 +1,4 @@
use crate::{get_field, make_term, to_pyerr, Schema}; use crate::{get_field, make_term, to_pyerr, DocAddress, Schema};
use pyo3::{ use pyo3::{
exceptions, exceptions,
prelude::*, prelude::*,
@ -100,7 +100,7 @@ impl Query {
let terms = field_values let terms = field_values
.into_iter() .into_iter()
.map(|field_value| { .map(|field_value| {
make_term(&schema.inner, field_name, &field_value) make_term(&schema.inner, field_name, field_value)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let inner = tv::query::TermSetQuery::new(terms); let inner = tv::query::TermSetQuery::new(terms);
@ -138,7 +138,7 @@ impl Query {
transposition_cost_one: bool, transposition_cost_one: bool,
prefix: bool, prefix: bool,
) -> PyResult<Query> { ) -> PyResult<Query> {
let term = make_term(&schema.inner, field_name, &text)?; let term = make_term(&schema.inner, field_name, text)?;
let inner = if prefix { let inner = if prefix {
tv::query::FuzzyTermQuery::new_prefix( tv::query::FuzzyTermQuery::new_prefix(
term, term,
@ -272,6 +272,50 @@ impl Query {
} }
} }
#[staticmethod]
#[pyo3(signature = (doc_address, min_doc_frequency = Some(5), max_doc_frequency = None, min_term_frequency = Some(2), max_query_terms = Some(25), min_word_length = None, max_word_length = None, boost_factor = Some(1.0), stop_words = vec![]))]
#[allow(clippy::too_many_arguments)]
pub(crate) fn more_like_this_query(
doc_address: &DocAddress,
min_doc_frequency: Option<u64>,
max_doc_frequency: Option<u64>,
min_term_frequency: Option<usize>,
max_query_terms: Option<usize>,
min_word_length: Option<usize>,
max_word_length: Option<usize>,
boost_factor: Option<f32>,
stop_words: Vec<String>,
) -> PyResult<Query> {
let mut builder = tv::query::MoreLikeThisQuery::builder();
if let Some(value) = min_doc_frequency {
builder = builder.with_min_doc_frequency(value);
}
if let Some(value) = max_doc_frequency {
builder = builder.with_max_doc_frequency(value);
}
if let Some(value) = min_term_frequency {
builder = builder.with_min_term_frequency(value);
}
if let Some(value) = max_query_terms {
builder = builder.with_max_query_terms(value);
}
if let Some(value) = min_word_length {
builder = builder.with_min_word_length(value);
}
if let Some(value) = max_word_length {
builder = builder.with_max_word_length(value);
}
if let Some(value) = boost_factor {
builder = builder.with_boost_factor(value);
}
builder = builder.with_stop_words(stop_words);
let inner = builder.with_document(tv::DocAddress::from(doc_address));
Ok(Query {
inner: Box::new(inner),
})
}
/// Construct a Tantivy's ConstScoreQuery /// Construct a Tantivy's ConstScoreQuery
#[staticmethod] #[staticmethod]
#[pyo3(signature = (query, score))] #[pyo3(signature = (query, score))]

View File

@ -153,6 +153,7 @@ impl Searcher {
/// ///
/// Raises a ValueError if there was an error with the search. /// Raises a ValueError if there was an error with the search.
#[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))] #[pyo3(signature = (query, limit = 10, count = true, order_by_field = None, offset = 0, order = Order::Desc))]
#[allow(clippy::too_many_arguments)]
fn search( fn search(
&self, &self,
py: Python, py: Python,

View File

@ -227,6 +227,20 @@ class Query:
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query: def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
pass pass
@staticmethod
def more_like_this_query(
doc_address: DocAddress,
min_doc_frequency: Optional[int] = 5,
max_doc_frequency: Optional[int] = None,
min_term_frequency: Optional[int] = 2,
max_query_terms: Optional[int] = 25,
min_word_length: Optional[int] = None,
max_word_length: Optional[int] = None,
boost_factor: Optional[float] = 1.0,
stop_words: list[str] = []
) -> Query:
pass
@staticmethod @staticmethod
def const_score_query(query: Query, score: float) -> Query: def const_score_query(query: Query, score: float) -> Query:
pass pass

View File

@ -1090,6 +1090,42 @@ class TestQuery(object):
): ):
Query.regex_query(index.schema, "body", "fish(") Query.regex_query(index.schema, "body", "fish(")
def test_more_like_this_query(self, ram_index):
index = ram_index
# first, search the target doc
query = Query.term_query(index.schema, "title", "man")
result = index.searcher().search(query, 1)
_, doc_address = result.hits[0]
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["The Old Man and the Sea"]
# construct the default MLT Query
mlt_query = Query.more_like_this_query(doc_address)
assert (
repr(mlt_query)
== "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(5), max_doc_frequency: None, min_term_frequency: Some(2), max_query_terms: Some(25), min_word_length: None, max_word_length: None, boost_factor: Some(1.0), stop_words: [] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })"
)
result = index.searcher().search(mlt_query, 10)
assert len(result.hits) == 0
# construct a fine-tuned MLT Query
mlt_query = Query.more_like_this_query(
doc_address,
min_doc_frequency=2,
max_doc_frequency=10,
min_term_frequency=1,
max_query_terms=10,
min_word_length=2,
max_word_length=20,
boost_factor=2.0,
stop_words=["fish"])
assert (
repr(mlt_query)
== "Query(MoreLikeThisQuery { mlt: MoreLikeThis { min_doc_frequency: Some(2), max_doc_frequency: Some(10), min_term_frequency: Some(1), max_query_terms: Some(10), min_word_length: Some(2), max_word_length: Some(20), boost_factor: Some(2.0), stop_words: [\"fish\"] }, target: DocumentAdress(DocAddress { segment_ord: 0, doc_id: 0 }) })"
)
result = index.searcher().search(mlt_query, 10)
assert len(result.hits) > 0
def test_const_score_query(self, ram_index): def test_const_score_query(self, ram_index):
index = ram_index index = ram_index