diff --git a/src/query.rs b/src/query.rs index 716e0b8..00e35ca 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,4 +1,4 @@ -use crate::{make_term, Schema}; +use crate::{get_field, make_term, to_pyerr, Schema}; use pyo3::{ exceptions, prelude::*, @@ -187,4 +187,23 @@ impl Query { inner: Box::new(inner), }) } + + #[staticmethod] + #[pyo3(signature = (schema, field_name, regex_pattern))] + pub(crate) fn regex_query( + schema: &Schema, + field_name: &str, + regex_pattern: &str, + ) -> PyResult { + let field = get_field(&schema.inner, field_name)?; + + let inner_result = + tv::query::RegexQuery::from_pattern(regex_pattern, field); + match inner_result { + Ok(inner) => Ok(Query { + inner: Box::new(inner), + }), + Err(e) => Err(to_pyerr(e)), + } + } } diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi index 710358c..524523c 100644 --- a/tantivy/tantivy.pyi +++ b/tantivy/tantivy.pyi @@ -2,108 +2,105 @@ import datetime from enum import Enum from typing import Any, Optional, Sequence - class Schema: pass class SchemaBuilder: - @staticmethod def is_valid_field_name(name: str) -> bool: pass def add_text_field( - self, - name: str, - stored: bool = False, - tokenizer_name: str = "default", - index_option: str = "position", + self, + name: str, + stored: bool = False, + tokenizer_name: str = "default", + index_option: str = "position", ) -> SchemaBuilder: pass def add_integer_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, ) -> SchemaBuilder: pass def add_float_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, ) -> SchemaBuilder: pass def add_unsigned_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, ) -> SchemaBuilder: pass def add_boolean_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, ) -> SchemaBuilder: pass def add_date_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, ) -> SchemaBuilder: pass def add_json_field( - self, - name: str, - stored: bool = False, - tokenizer_name: str = "default", - index_option: str = "position", + self, + name: str, + stored: bool = False, + tokenizer_name: str = "default", + index_option: str = "position", ) -> SchemaBuilder: pass def add_facet_field( - self, - name: str, + self, + name: str, ) -> SchemaBuilder: pass def add_bytes_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, - index_option: str = "position", + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, + index_option: str = "position", ) -> SchemaBuilder: pass def add_ip_addr_field( - self, - name: str, - stored: bool = False, - indexed: bool = False, - fast: bool = False, + self, + name: str, + stored: bool = False, + indexed: bool = False, + fast: bool = False, ) -> SchemaBuilder: pass def build(self) -> Schema: pass - class Facet: @staticmethod def from_encoded(encoded_bytes: bytes) -> Facet: @@ -130,9 +127,7 @@ class Facet: def to_path_str(self) -> str: pass - class Document: - def __new__(cls, **kwargs) -> Document: pass @@ -194,7 +189,12 @@ class Occur(Enum): class Query: @staticmethod - def term_query(schema: Schema, field_name: str, field_value: Any, index_option: str = "position") -> Query: + def term_query( + schema: Schema, + field_name: str, + field_value: Any, + index_option: str = "position", + ) -> Query: pass @staticmethod @@ -202,9 +202,16 @@ class Query: pass @staticmethod - def fuzzy_term_query(schema: Schema, field_name: str, text: str, distance: int = 1, transposition_cost_one: bool = True, prefix = False) -> Query: + def fuzzy_term_query( + schema: Schema, + field_name: str, + text: str, + distance: int = 1, + transposition_cost_one: bool = True, + prefix=False, + ) -> Query: pass - + @staticmethod def boolean_query(subqueries: Sequence[tuple[Occur, Query]]) -> Query: pass @@ -218,13 +225,15 @@ class Query: pass + @staticmethod + def regex_query(schema: Schema, field_name: str, regex_pattern: str) -> Query: + pass + class Order(Enum): Asc = 1 Desc = 2 - class DocAddress: - def __new__(cls, segment_ord: int, doc: int) -> DocAddress: pass @@ -237,22 +246,19 @@ class DocAddress: pass class SearchResult: - @property def hits(self) -> list[tuple[Any, DocAddress]]: pass - class Searcher: - def search( - self, - query: Query, - limit: int = 10, - count: bool = True, - order_by_field: Optional[str] = None, - offset: int = 0, - order: Order = Order.Desc, + self, + query: Query, + limit: int = 10, + count: bool = True, + order_by_field: Optional[str] = None, + offset: int = 0, + order: Order = Order.Desc, ) -> SearchResult: pass @@ -267,9 +273,7 @@ class Searcher: def doc(self, doc_address: DocAddress) -> Document: pass - class IndexWriter: - def add_document(self, doc: Document) -> int: pass @@ -298,10 +302,10 @@ class IndexWriter: def wait_merging_threads(self) -> None: pass - class Index: - - def __new__(cls, schema: Schema, path: Optional[str] = None, reuse: bool = True) -> Index: + def __new__( + cls, schema: Schema, path: Optional[str] = None, reuse: bool = True + ) -> Index: pass @staticmethod @@ -311,7 +315,9 @@ class Index: def writer(self, heap_size: int = 128_000_000, num_threads: int = 0) -> IndexWriter: pass - def config_reader(self, reload_policy: str = "commit", num_warmers: int = 0) -> None: + def config_reader( + self, reload_policy: str = "commit", num_warmers: int = 0 + ) -> None: pass def searcher(self) -> Searcher: @@ -328,15 +334,17 @@ class Index: def reload(self) -> None: pass - def parse_query(self, query: str, default_field_names: Optional[list[str]] = None) -> Query: + def parse_query( + self, query: str, default_field_names: Optional[list[str]] = None + ) -> Query: pass - def parse_query_lenient(self, query: str, default_field_names: Optional[list[str]] = None) -> Query: + def parse_query_lenient( + self, query: str, default_field_names: Optional[list[str]] = None + ) -> Query: pass - class Range: - @property def start(self) -> int: pass @@ -345,24 +353,17 @@ class Range: def end(self) -> int: pass - class Snippet: - def to_html(self) -> str: pass def highlighted(self) -> list[Range]: pass - class SnippetGenerator: - @staticmethod def create( - searcher: Searcher, - query: Query, - schema: Schema, - field_name: str + searcher: Searcher, query: Query, schema: Schema, field_name: str ) -> SnippetGenerator: pass diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 7ff8544..0f6b987 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -995,3 +995,36 @@ class TestQuery(object): # no boost type error with pytest.raises(TypeError, match = r"Query.boost_query\(\) missing 1 required positional argument: 'boost'"): Query.boost_query(query1) + + + def test_regex_query(self, ram_index): + index = ram_index + + query = Query.regex_query(index.schema, "body", "fish") + result = index.searcher().search(query, 10) + assert len(result.hits) == 1 + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["title"] == ["The Old Man and the Sea"] + + query = Query.regex_query(index.schema, "title", "(?:man|men)") + result = index.searcher().search(query, 10) + assert len(result.hits) == 2 + _, doc_address = result.hits[0] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["title"] == ["The Old Man and the Sea"] + _, doc_address = result.hits[1] + searched_doc = index.searcher().doc(doc_address) + assert searched_doc["title"] == ["Of Mice and Men"] + + # unknown field in the schema + with pytest.raises( + ValueError, match="Field `unknown_field` is not defined in the schema." + ): + Query.regex_query(index.schema, "unknown_field", "fish") + + # invalid regex pattern + with pytest.raises( + ValueError, match=r"An invalid argument was passed: 'fish\('" + ): + Query.regex_query(index.schema, "body", "fish(")