Expose const score query (#256)
parent
8216f17d60
commit
10553de76e
16
src/query.rs
16
src/query.rs
|
@ -157,6 +157,7 @@ impl Query {
|
|||
})
|
||||
}
|
||||
|
||||
/// Construct a Tantivy's BooleanQuery
|
||||
#[staticmethod]
|
||||
#[pyo3(signature = (subqueries))]
|
||||
pub(crate) fn boolean_query(
|
||||
|
@ -199,6 +200,7 @@ impl Query {
|
|||
})
|
||||
}
|
||||
|
||||
/// Construct a Tantivy's BoostQuery
|
||||
#[staticmethod]
|
||||
#[pyo3(signature = (query, boost))]
|
||||
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
|
||||
|
@ -208,6 +210,7 @@ impl Query {
|
|||
})
|
||||
}
|
||||
|
||||
/// Construct a Tantivy's RegexQuery
|
||||
#[staticmethod]
|
||||
#[pyo3(signature = (schema, field_name, regex_pattern))]
|
||||
pub(crate) fn regex_query(
|
||||
|
@ -226,4 +229,17 @@ impl Query {
|
|||
Err(e) => Err(to_pyerr(e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a Tantivy's ConstScoreQuery
|
||||
#[staticmethod]
|
||||
#[pyo3(signature = (query, score))]
|
||||
pub(crate) fn const_score_query(
|
||||
query: Query,
|
||||
score: f32,
|
||||
) -> PyResult<Query> {
|
||||
let inner = tv::query::ConstScoreQuery::new(query.inner, score);
|
||||
Ok(Query {
|
||||
inner: Box::new(inner),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -233,6 +233,10 @@ class Query:
|
|||
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def const_score_query(query: Query, score: float) -> Query:
|
||||
pass
|
||||
|
||||
class Order(Enum):
|
||||
Asc = 1
|
||||
Desc = 2
|
||||
|
|
|
@ -1057,3 +1057,34 @@ class TestQuery(object):
|
|||
ValueError, match=r"An invalid argument was passed: 'fish\('"
|
||||
):
|
||||
Query.regex_query(index.schema, "body", "fish(")
|
||||
|
||||
def test_const_score_query(self, ram_index):
|
||||
index = ram_index
|
||||
|
||||
query = Query.regex_query(index.schema, "body", "fish")
|
||||
const_score_query = Query.const_score_query(
|
||||
query, score = 1.0
|
||||
)
|
||||
result = index.searcher().search(const_score_query, 10)
|
||||
assert len(result.hits) == 1
|
||||
score, _ = result.hits[0]
|
||||
# the score should be 1.0
|
||||
assert score == pytest.approx(1.0)
|
||||
|
||||
const_score_query = Query.const_score_query(
|
||||
Query.const_score_query(
|
||||
query, score = 1.0
|
||||
), score = 0.5
|
||||
)
|
||||
|
||||
result = index.searcher().search(const_score_query, 10)
|
||||
assert len(result.hits) == 1
|
||||
score, _ = result.hits[0]
|
||||
# nested const score queries should retain the
|
||||
# score of the outer query
|
||||
assert score == pytest.approx(0.5)
|
||||
|
||||
# wrong score type
|
||||
with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"):
|
||||
Query.const_score_query(query, "0.1")
|
||||
|
Loading…
Reference in New Issue