Expose const score query (#256)

master
alex-au-922 2024-04-28 18:54:21 +08:00 committed by GitHub
parent 8216f17d60
commit 10553de76e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 0 deletions

View File

@ -157,6 +157,7 @@ impl Query {
})
}
/// Construct a Tantivy's BooleanQuery
#[staticmethod]
#[pyo3(signature = (subqueries))]
pub(crate) fn boolean_query(
@ -199,6 +200,7 @@ impl Query {
})
}
/// Construct a Tantivy's BoostQuery
#[staticmethod]
#[pyo3(signature = (query, boost))]
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
@ -208,6 +210,7 @@ impl Query {
})
}
/// Construct a Tantivy's RegexQuery
#[staticmethod]
#[pyo3(signature = (schema, field_name, regex_pattern))]
pub(crate) fn regex_query(
@ -226,4 +229,17 @@ impl Query {
Err(e) => Err(to_pyerr(e)),
}
}
/// Construct a Tantivy's ConstScoreQuery
#[staticmethod]
#[pyo3(signature = (query, score))]
pub(crate) fn const_score_query(
query: Query,
score: f32,
) -> PyResult<Query> {
let inner = tv::query::ConstScoreQuery::new(query.inner, score);
Ok(Query {
inner: Box::new(inner),
})
}
}

View File

@ -233,6 +233,10 @@ class Query:
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
pass
@staticmethod
def const_score_query(query: Query, score: float) -> Query:
pass
class Order(Enum):
Asc = 1
Desc = 2

View File

@ -1057,3 +1057,34 @@ class TestQuery(object):
ValueError, match=r"An invalid argument was passed: 'fish\('"
):
Query.regex_query(index.schema, "body", "fish(")
def test_const_score_query(self, ram_index):
index = ram_index
query = Query.regex_query(index.schema, "body", "fish")
const_score_query = Query.const_score_query(
query, score = 1.0
)
result = index.searcher().search(const_score_query, 10)
assert len(result.hits) == 1
score, _ = result.hits[0]
# the score should be 1.0
assert score == pytest.approx(1.0)
const_score_query = Query.const_score_query(
Query.const_score_query(
query, score = 1.0
), score = 0.5
)
result = index.searcher().search(const_score_query, 10)
assert len(result.hits) == 1
score, _ = result.hits[0]
# nested const score queries should retain the
# score of the outer query
assert score == pytest.approx(0.5)
# wrong score type
with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"):
Query.const_score_query(query, "0.1")