Expose const score query (#256)
parent
8216f17d60
commit
10553de76e
16
src/query.rs
16
src/query.rs
|
@ -157,6 +157,7 @@ impl Query {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct a Tantivy's BooleanQuery
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
#[pyo3(signature = (subqueries))]
|
#[pyo3(signature = (subqueries))]
|
||||||
pub(crate) fn boolean_query(
|
pub(crate) fn boolean_query(
|
||||||
|
@ -199,6 +200,7 @@ impl Query {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct a Tantivy's BoostQuery
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
#[pyo3(signature = (query, boost))]
|
#[pyo3(signature = (query, boost))]
|
||||||
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
|
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
|
||||||
|
@ -208,6 +210,7 @@ impl Query {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct a Tantivy's RegexQuery
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
#[pyo3(signature = (schema, field_name, regex_pattern))]
|
#[pyo3(signature = (schema, field_name, regex_pattern))]
|
||||||
pub(crate) fn regex_query(
|
pub(crate) fn regex_query(
|
||||||
|
@ -226,4 +229,17 @@ impl Query {
|
||||||
Err(e) => Err(to_pyerr(e)),
|
Err(e) => Err(to_pyerr(e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct a Tantivy's ConstScoreQuery
|
||||||
|
#[staticmethod]
|
||||||
|
#[pyo3(signature = (query, score))]
|
||||||
|
pub(crate) fn const_score_query(
|
||||||
|
query: Query,
|
||||||
|
score: f32,
|
||||||
|
) -> PyResult<Query> {
|
||||||
|
let inner = tv::query::ConstScoreQuery::new(query.inner, score);
|
||||||
|
Ok(Query {
|
||||||
|
inner: Box::new(inner),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -233,6 +233,10 @@ class Query:
|
||||||
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
|
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def const_score_query(query: Query, score: float) -> Query:
|
||||||
|
pass
|
||||||
|
|
||||||
class Order(Enum):
|
class Order(Enum):
|
||||||
Asc = 1
|
Asc = 1
|
||||||
Desc = 2
|
Desc = 2
|
||||||
|
|
|
@ -1057,3 +1057,34 @@ class TestQuery(object):
|
||||||
ValueError, match=r"An invalid argument was passed: 'fish\('"
|
ValueError, match=r"An invalid argument was passed: 'fish\('"
|
||||||
):
|
):
|
||||||
Query.regex_query(index.schema, "body", "fish(")
|
Query.regex_query(index.schema, "body", "fish(")
|
||||||
|
|
||||||
|
def test_const_score_query(self, ram_index):
|
||||||
|
index = ram_index
|
||||||
|
|
||||||
|
query = Query.regex_query(index.schema, "body", "fish")
|
||||||
|
const_score_query = Query.const_score_query(
|
||||||
|
query, score = 1.0
|
||||||
|
)
|
||||||
|
result = index.searcher().search(const_score_query, 10)
|
||||||
|
assert len(result.hits) == 1
|
||||||
|
score, _ = result.hits[0]
|
||||||
|
# the score should be 1.0
|
||||||
|
assert score == pytest.approx(1.0)
|
||||||
|
|
||||||
|
const_score_query = Query.const_score_query(
|
||||||
|
Query.const_score_query(
|
||||||
|
query, score = 1.0
|
||||||
|
), score = 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
result = index.searcher().search(const_score_query, 10)
|
||||||
|
assert len(result.hits) == 1
|
||||||
|
score, _ = result.hits[0]
|
||||||
|
# nested const score queries should retain the
|
||||||
|
# score of the outer query
|
||||||
|
assert score == pytest.approx(0.5)
|
||||||
|
|
||||||
|
# wrong score type
|
||||||
|
with pytest.raises(TypeError, match = r"argument 'score': must be real number, not str"):
|
||||||
|
Query.const_score_query(query, "0.1")
|
||||||
|
|
Loading…
Reference in New Issue