diff --git a/src/lib.rs b/src/lib.rs
index 302a321..2bf9e3e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 use ::tantivy as tv;
+use ::tantivy::schema::{Term, Value};
 use pyo3::{exceptions, prelude::*, wrap_pymodule};
 
 mod document;
@@ -20,6 +21,8 @@ use schemabuilder::SchemaBuilder;
 use searcher::{DocAddress, Order, SearchResult, Searcher};
 use snippet::{Snippet, SnippetGenerator};
 
+use crate::document::extract_value;
+
 /// Python bindings for the search engine library Tantivy.
 ///
 /// Tantivy is a full text search engine library written in rust.
@@ -153,3 +156,38 @@ pub(crate) fn get_field(
 
     Ok(field)
 }
+
+/// Builds a Tantivy `Term` for the field `field_name` from a Python value.
+///
+/// The field is resolved with `get_field` and the Python object is converted
+/// with `extract_value`; an error from either call is propagated unchanged.
+///
+/// # Errors
+///
+/// Returns a `PyValueError` when the extracted value is of a kind that none
+/// of the term constructors matched below accepts.
+pub(crate) fn make_term(
+    schema: &tv::schema::Schema,
+    field_name: &str,
+    field_value: &PyAny,
+) -> PyResult<Term> {
+    let field = get_field(schema, field_name)?;
+    let value = extract_value(field_value)?;
+    let term = match value {
+        Value::Str(text) => Term::from_field_text(field, &text),
+        Value::U64(num) => Term::from_field_u64(field, num),
+        Value::I64(num) => Term::from_field_i64(field, num),
+        Value::F64(num) => Term::from_field_f64(field, num),
+        Value::Date(d) => Term::from_field_date(field, d),
+        Value::Facet(facet) => Term::from_facet(field, &facet),
+        Value::Bool(b) => Term::from_field_bool(field, b),
+        Value::IpAddr(i) => Term::from_field_ip_addr(field, i),
+        _ => {
+            return Err(exceptions::PyValueError::new_err(format!(
+                "Can't create a term for Field `{field_name}` with value `{field_value}`."
+            )))
+        }
+    };
+
+    Ok(term)
+}
diff --git a/src/query.rs b/src/query.rs
index ef841a0..53da089 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -1,4 +1,5 @@
-use pyo3::prelude::*;
+use crate::{make_term, Schema};
+use pyo3::{exceptions, prelude::*, types::PyAny};
 use tantivy as tv;
 
 /// Tantivy's Query
@@ -18,4 +19,33 @@ impl Query {
     fn __repr__(&self) -> PyResult<String> {
         Ok(format!("Query({:?})", self.get()))
     }
+
+    /// Construct a Tantivy's TermQuery.
+    ///
+    /// `index_option` selects the `IndexRecordOption` handed to the query:
+    /// `"position"` (the default) maps to `WithFreqsAndPositions`, `"freq"`
+    /// to `WithFreqs` and `"basic"` to `Basic`. Any other string is rejected
+    /// with a `PyValueError`. Errors from `make_term` are propagated.
+    #[staticmethod]
+    #[pyo3(signature = (schema, field_name, field_value, index_option = "position"))]
+    pub(crate) fn term_query(
+        schema: &Schema,
+        field_name: &str,
+        field_value: &PyAny,
+        index_option: &str,
+    ) -> PyResult<Query> {
+        let term = make_term(&schema.inner, field_name, field_value)?;
+        let index_option = match index_option {
+            "position" => tv::schema::IndexRecordOption::WithFreqsAndPositions,
+            "freq" => tv::schema::IndexRecordOption::WithFreqs,
+            "basic" => tv::schema::IndexRecordOption::Basic,
+            _ => return Err(exceptions::PyValueError::new_err(
+                "Invalid index option, valid choices are: 'basic', 'freq' and 'position'"
+            ))
+        };
+        let inner = tv::query::TermQuery::new(term, index_option);
+        Ok(Query {
+            inner: Box::new(inner),
+        })
+    }
 }
diff --git a/tantivy/tantivy.pyi b/tantivy/tantivy.pyi
index 3d2c580..06c5690 100644
--- a/tantivy/tantivy.pyi
+++ b/tantivy/tantivy.pyi
@@ -189,7 +189,9 @@ class Document:
 
 
 class Query:
-    pass
+    @staticmethod
+    def term_query(schema: Schema, field_name: str, field_value: Any, index_option: str = "position") -> Query:
+        pass
 
 
 class Order(Enum):
diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py
index 6b4c2b4..80c1719 100644
--- a/tests/tantivy_test.py
+++ b/tests/tantivy_test.py
@@ -7,7 +7,7 @@ import tantivy
 import pickle
 import pytest
 import tantivy
-from tantivy import Document, Index, SchemaBuilder, SnippetGenerator
+from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query
 
 
 def schema():
@@ -925,3 +925,15 @@ class TestSnippets(object):
         assert first.end == 23
         html_snippet = snippet.to_html()
         assert html_snippet == "The Old Man and the Sea"
+
+
+class TestQuery(object):
+    def test_term_query(self, ram_index):
+        index = ram_index
+        query = Query.term_query(index.schema, "title", "sea")
+
+        result = index.searcher().search(query, 10)
+        assert len(result.hits) == 1
+        _, doc_address = result.hits[0]
+        searched_doc = index.searcher().doc(doc_address)
+        assert searched_doc["title"] == ["The Old Man and the Sea"]