Expose boost query (#250)
parent
ed7374c7bd
commit
c74990aeb8
|
@ -178,4 +178,13 @@ impl Query {
|
||||||
inner: Box::new(dismax_query),
|
inner: Box::new(dismax_query),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[staticmethod]
|
||||||
|
#[pyo3(signature = (query, boost))]
|
||||||
|
pub(crate) fn boost_query(query: Query, boost: f32) -> PyResult<Query> {
|
||||||
|
let inner = tv::query::BoostQuery::new(query.inner, boost);
|
||||||
|
Ok(Query {
|
||||||
|
inner: Box::new(inner),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -212,6 +212,10 @@ class Query:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
|
def disjunction_max_query(subqueries: Sequence[Query], tie_breaker: Optional[float] = None) -> Query:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
def boost_query(query: Query, boost: float) -> Query:
    """Wrap ``query`` in a boost query using the factor ``boost``.

    Args:
        query: The query to wrap.
        boost: The boost factor to attach to the wrapped query.

    Returns:
        A new ``Query`` representing the boosted query.
    """
    pass
|
||||||
|
|
||||||
|
|
||||||
class Order(Enum):
|
class Order(Enum):
|
||||||
|
|
|
@ -827,20 +827,20 @@ class TestQuery(object):
|
||||||
(Occur.Must, query1),
|
(Occur.Must, query1),
|
||||||
(Occur.Must, query2)
|
(Occur.Must, query2)
|
||||||
])
|
])
|
||||||
|
|
||||||
# no document should match both queries
|
# no document should match both queries
|
||||||
result = index.searcher().search(query, 10)
|
result = index.searcher().search(query, 10)
|
||||||
assert len(result.hits) == 0
|
assert len(result.hits) == 0
|
||||||
|
|
||||||
query = Query.boolean_query([
|
query = Query.boolean_query([
|
||||||
(Occur.Should, query1),
|
(Occur.Should, query1),
|
||||||
(Occur.Should, query2)
|
(Occur.Should, query2)
|
||||||
])
|
])
|
||||||
|
|
||||||
# two documents should match, one for each query
|
# two documents should match, one for each query
|
||||||
result = index.searcher().search(query, 10)
|
result = index.searcher().search(query, 10)
|
||||||
assert len(result.hits) == 2
|
assert len(result.hits) == 2
|
||||||
|
|
||||||
titles = set()
|
titles = set()
|
||||||
for _, doc_address in result.hits:
|
for _, doc_address in result.hits:
|
||||||
titles.update(index.searcher().doc(doc_address)["title"])
|
titles.update(index.searcher().doc(doc_address)["title"])
|
||||||
|
@ -848,31 +848,31 @@ class TestQuery(object):
|
||||||
"The Old Man and the Sea" in titles and
|
"The Old Man and the Sea" in titles and
|
||||||
"Of Mice and Men" in titles
|
"Of Mice and Men" in titles
|
||||||
)
|
)
|
||||||
|
|
||||||
query = Query.boolean_query([
|
query = Query.boolean_query([
|
||||||
(Occur.MustNot, query1),
|
(Occur.MustNot, query1),
|
||||||
(Occur.Must, query1)
|
(Occur.Must, query1)
|
||||||
])
|
])
|
||||||
|
|
||||||
# must not should take precedence over must
|
# must not should take precedence over must
|
||||||
result = index.searcher().search(query, 10)
|
result = index.searcher().search(query, 10)
|
||||||
assert len(result.hits) == 0
|
assert len(result.hits) == 0
|
||||||
|
|
||||||
query = Query.boolean_query((
|
query = Query.boolean_query((
|
||||||
(Occur.Should, query1),
|
(Occur.Should, query1),
|
||||||
(Occur.Should, query2)
|
(Occur.Should, query2)
|
||||||
))
|
))
|
||||||
|
|
||||||
# the Vec signature should fit the tuple signature
|
# the Vec signature should fit the tuple signature
|
||||||
result = index.searcher().search(query, 10)
|
result = index.searcher().search(query, 10)
|
||||||
assert len(result.hits) == 2
|
assert len(result.hits) == 2
|
||||||
|
|
||||||
# test invalid queries
|
# test invalid queries
|
||||||
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
|
with pytest.raises(ValueError, match = "expected tuple of length 2, but got tuple of length 3"):
|
||||||
Query.boolean_query([
|
Query.boolean_query([
|
||||||
(Occur.Must, Occur.Must, query1),
|
(Occur.Must, Occur.Must, query1),
|
||||||
])
|
])
|
||||||
|
|
||||||
# test swapping the order of the tuple
|
# test swapping the order of the tuple
|
||||||
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
|
with pytest.raises(TypeError, match = r"'Query' object cannot be converted to 'Occur'"):
|
||||||
Query.boolean_query([
|
Query.boolean_query([
|
||||||
|
@ -899,3 +899,99 @@ class TestQuery(object):
|
||||||
|
|
||||||
with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
|
with pytest.raises(TypeError, match = r"'str' object cannot be converted to 'Query'"):
|
||||||
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)
|
query = Query.disjunction_max_query([query1, "not a query"], tie_breaker=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_boost_query(self, ram_index):
    """Check ``Query.boost_query``: repr output, scoring and argument errors.

    Covers a plain boost, a boosted boolean query, decimal / zero / nested /
    negative boost factors, and the TypeErrors raised for bad arguments.
    """
    index = ram_index
    query1 = Query.term_query(index.schema, "title", "sea")
    boosted_query = Query.boost_query(query1, 2.0)

    # Normal boost query
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2))"""
    )

    # Unboosted reference score, used below to verify that boosts really
    # scale the score instead of only checking a loose absolute range.
    baseline = index.searcher().search(query1, 10)
    baseline_scores = [score for score, _ in baseline.hits]
    assert len(baseline_scores) == 1
    baseline_score = baseline_scores[0]
    assert baseline_score > 0

    query2 = Query.fuzzy_term_query(index.schema, "title", "ice")
    combined_query = Query.boolean_query([
        (Occur.Should, boosted_query),
        (Occur.Should, query2),
    ])
    boosted_query = Query.boost_query(combined_query, 2.0)

    # Boosted boolean query
    assert (
        repr(boosted_query)
        == """Query(Boost(query=BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=2)), (Should, FuzzyTermQuery { term: Term(field=0, type=Str, "ice"), distance: 1, transposition_cost_one: true, prefix: false })] }, boost=2))"""
    )

    boosted_query = Query.boost_query(query1, 0.1)

    # Check for decimal boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1))"""
    )

    boosted_query = Query.boost_query(query1, 0.0)

    # Check for zero boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0))"""
    )
    result = index.searcher().search(boosted_query, 10)
    # Boosting must not change which documents match; asserting the hit
    # count also keeps the score loop below from passing vacuously.
    assert len(result.hits) == 1
    for _score, _ in result.hits:
        # the score should be 0.0
        assert _score == pytest.approx(0.0)

    boosted_query = Query.boost_query(
        Query.boost_query(
            query1, 0.1
        ), 0.1
    )

    # Check for nested boost queries
    assert (
        repr(boosted_query)
        == """Query(Boost(query=Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=0.1), boost=0.1))"""
    )
    result = index.searcher().search(boosted_query, 10)
    assert len(result.hits) == 1
    for _score, _ in result.hits:
        # The absolute BM25 score is not known up front, so compare with
        # the unboosted baseline: nested boosts multiply (0.1 * 0.1).
        # The tolerance leaves room for 32-bit float rounding.
        assert _score == pytest.approx(baseline_score * 0.1 * 0.1, rel=1e-4)

    boosted_query = Query.boost_query(
        query1, -0.1
    )

    # Check for negative boost values
    assert (
        repr(boosted_query)
        == """Query(Boost(query=TermQuery(Term(field=0, type=Str, "sea")), boost=-0.1))"""
    )

    result = index.searcher().search(boosted_query, 10)
    # Even with a negative boost, the query should still match the document
    assert len(result.hits) == 1
    titles = set()
    for _score, doc_address in result.hits:

        # the score should be negative
        assert _score < 0
        titles.update(index.searcher().doc(doc_address)["title"])
    assert titles == {"The Old Man and the Sea"}

    # wrong query type
    with pytest.raises(TypeError, match=r"'int' object cannot be converted to 'Query'"):
        Query.boost_query(1, 0.1)

    # wrong boost type
    with pytest.raises(TypeError, match=r"argument 'boost': must be real number, not str"):
        Query.boost_query(query1, "0.1")

    # no boost type error
    with pytest.raises(TypeError, match=r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
        Query.boost_query(query1)
|
||||||
|
|
Loading…
Reference in New Issue