expose regex query (#241)

Co-authored-by: alexau <alexau@hket.com>
master
alex-au-922 2024-04-25 10:57:09 +08:00 committed by GitHub
parent c74990aeb8
commit 5c3666349b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 140 additions and 87 deletions

View File

@ -1,4 +1,4 @@
use crate::{make_term, Schema};
use crate::{get_field, make_term, to_pyerr, Schema};
use pyo3::{
exceptions,
prelude::*,
@ -187,4 +187,23 @@ impl Query {
inner: Box::new(inner),
})
}
#[staticmethod]
#[pyo3(signature = (schema, field_name, regex_pattern))]
pub(crate) fn regex_query(
schema: &Schema,
field_name: &str,
regex_pattern: &str,
) -> PyResult<Query> {
let field = get_field(&schema.inner, field_name)?;
let inner_result =
tv::query::RegexQuery::from_pattern(regex_pattern, field);
match inner_result {
Ok(inner) => Ok(Query {
inner: Box::new(inner),
}),
Err(e) => Err(to_pyerr(e)),
}
}
}

View File

@ -2,12 +2,10 @@ import datetime
from enum import Enum
from typing import Any, Optional, Sequence
class Schema:
pass
class SchemaBuilder:
@staticmethod
def is_valid_field_name(name: str) -> bool:
pass
@ -103,7 +101,6 @@ class SchemaBuilder:
def build(self) -> Schema:
pass
class Facet:
@staticmethod
def from_encoded(encoded_bytes: bytes) -> Facet:
@ -130,9 +127,7 @@ class Facet:
def to_path_str(self) -> str:
pass
class Document:
def __new__(cls, **kwargs) -> Document:
pass
@ -194,7 +189,12 @@ class Occur(Enum):
class Query:
@staticmethod
def term_query(schema: Schema, field_name: str, field_value: Any, index_option: str = "position") -> Query:
def term_query(
schema: Schema,
field_name: str,
field_value: Any,
index_option: str = "position",
) -> Query:
pass
@staticmethod
@ -202,7 +202,14 @@ class Query:
pass
@staticmethod
def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query:
def fuzzy_term_query(
schema: Schema,
field_name: str,
text: str,
distance: int = 1,
transposition_cost_one: bool = True,
prefix=False,
) -> Query:
pass
@staticmethod
@ -218,13 +225,15 @@ class Query:
pass
@staticmethod
def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query:
pass
class Order(Enum):
Asc = 1
Desc = 2
class DocAddress:
def __new__(cls, segment_ord: int, doc: int) -> DocAddress:
pass
@ -237,14 +246,11 @@ class DocAddress:
pass
class SearchResult:
@property
def hits(self) -> list[tuple[Any, DocAddress]]:
pass
class Searcher:
def search(
self,
query: Query,
@ -267,9 +273,7 @@ class Searcher:
def doc(self, doc_address: DocAddress) -> Document:
pass
class IndexWriter:
def add_document(self, doc: Document) -> int:
pass
@ -298,10 +302,10 @@ class IndexWriter:
def wait_merging_threads(self) -> None:
pass
class Index:
def __new__(cls, schema: Schema, path: Optional[str] = None, reuse: bool = True) -> Index:
def __new__(
cls, schema: Schema, path: Optional[str] = None, reuse: bool = True
) -> Index:
pass
@staticmethod
@ -311,7 +315,9 @@ class Index:
def writer(self, heap_size: int = 128_000_000, num_threads: int = 0) -> IndexWriter:
pass
def config_reader(self, reload_policy: str = "commit", num_warmers: int = 0) -> None:
def config_reader(
self, reload_policy: str = "commit", num_warmers: int = 0
) -> None:
pass
def searcher(self) -> Searcher:
@ -328,15 +334,17 @@ class Index:
def reload(self) -> None:
pass
def parse_query(self, query: str, default_field_names: Optional[list[str]] = None) -> Query:
def parse_query(
self, query: str, default_field_names: Optional[list[str]] = None
) -> Query:
pass
def parse_query_lenient(self, query: str, default_field_names: Optional[list[str]] = None) -> Query:
def parse_query_lenient(
self, query: str, default_field_names: Optional[list[str]] = None
) -> Query:
pass
class Range:
@property
def start(self) -> int:
pass
@ -345,24 +353,17 @@ class Range:
def end(self) -> int:
pass
class Snippet:
def to_html(self) -> str:
pass
def highlighted(self) -> list[Range]:
pass
class SnippetGenerator:
@staticmethod
def create(
searcher: Searcher,
query: Query,
schema: Schema,
field_name: str
searcher: Searcher, query: Query, schema: Schema, field_name: str
) -> SnippetGenerator:
pass

View File

@ -995,3 +995,36 @@ class TestQuery(object):
# no boost type error
with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"):
Query.boost_query(query1)
def test_regex_query(self, ram_index):
    index = ram_index
    # No documents are written during this test, so one searcher can serve
    # every lookup below.
    searcher = index.searcher()

    # A plain literal is a valid pattern; exactly one hit is expected on "body".
    query = Query.regex_query(index.schema, "body", "fish")
    result = searcher.search(query, 10)
    assert len(result.hits) == 1
    _, doc_address = result.hits[0]
    searched_doc = searcher.doc(doc_address)
    assert searched_doc["title"] == ["The Old Man and the Sea"]

    # Alternation inside a non-capturing group; two titles are expected.
    query = Query.regex_query(index.schema, "title", "(?:man|men)")
    result = searcher.search(query, 10)
    assert len(result.hits) == 2
    titles = [searcher.doc(address)["title"] for _, address in result.hits]
    assert titles == [
        ["The Old Man and the Sea"],
        ["Of Mice and Men"],
    ]

    # unknown field in the schema
    # `match` is a regular expression, so the trailing period is escaped to
    # require a literal "." rather than any character.
    with pytest.raises(
        ValueError,
        match=r"Field `unknown_field` is not defined in the schema\.",
    ):
        Query.regex_query(index.schema, "unknown_field", "fish")

    # invalid regex pattern
    with pytest.raises(
        ValueError, match=r"An invalid argument was passed: 'fish\('"
    ):
        Query.regex_query(index.schema, "body", "fish(")