diff --git a/src/document.rs b/src/document.rs index 6601472..f737d9a 100644 --- a/src/document.rs +++ b/src/document.rs @@ -300,6 +300,15 @@ impl Document { add_value(self, field_name, value); } + /// Add a float value to the document. + /// + /// Args: + /// field_name (str): The field name for which we are adding the value. + /// value (f64): The float that will be added to the document. + fn add_float(&mut self, field_name: String, value: f64) { + add_value(self, field_name, value); + } + /// Add a date value to the document. /// /// Args: diff --git a/src/schemabuilder.rs b/src/schemabuilder.rs index 7fd6e56..17493c0 100644 --- a/src/schemabuilder.rs +++ b/src/schemabuilder.rs @@ -120,7 +120,7 @@ impl SchemaBuilder { ) -> PyResult { let builder = &mut self.builder; - let opts = SchemaBuilder::build_int_option(stored, indexed, fast)?; + let opts = SchemaBuilder::build_numeric_option(stored, indexed, fast)?; if let Some(builder) = builder.write().unwrap().as_mut() { builder.add_i64_field(name, opts); @@ -132,6 +132,28 @@ impl SchemaBuilder { Ok(self.clone()) } + #[pyo3(signature = (name, stored = false, indexed = false, fast = None))] + fn add_float_field( + &mut self, + name: &str, + stored: bool, + indexed: bool, + fast: Option<&str>, + ) -> PyResult { + let builder = &mut self.builder; + + let opts = SchemaBuilder::build_numeric_option(stored, indexed, fast)?; + + if let Some(builder) = builder.write().unwrap().as_mut() { + builder.add_f64_field(name, opts); + } else { + return Err(exceptions::PyValueError::new_err( + "Schema builder object isn't valid anymore.", + )); + } + Ok(self.clone()) + } + /// Add a new unsigned integer field to the schema. /// /// Args: @@ -162,7 +184,7 @@ impl SchemaBuilder { ) -> PyResult { let builder = &mut self.builder; - let opts = SchemaBuilder::build_int_option(stored, indexed, fast)?; + let opts = SchemaBuilder::build_numeric_option(stored, indexed, fast)?; if let Some(builder) = builder.write().unwrap().as_mut() { builder.add_u64_field(name, opts); @@ -343,7 +365,7 @@ impl SchemaBuilder { } impl SchemaBuilder { - fn build_int_option( + fn build_numeric_option( stored: bool, indexed: bool, fast: Option<&str>, diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 74f444d..0157103 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -12,6 +12,14 @@ def schema(): .build() ) +def schema_numeric_fields(): + return ( + SchemaBuilder() + .add_integer_field("id", stored=True, indexed=True) + .add_float_field("rating", stored=True, indexed=True) + .add_text_field("body", stored=True) + .build() + ) def create_index(dir=None): # assume all tests will use the same documents for now @@ -66,6 +74,46 @@ def create_index(dir=None): index.reload() return index +def create_index_with_numeric_fields(dir=None): + index = Index(schema_numeric_fields(), dir) + writer = index.writer() + + doc = Document() + doc.add_integer("id", 1) + doc.add_float("rating", 3.5) + doc.add_text( + "body", + ( + "He was an old man who fished alone in a skiff in" + "the Gulf Stream and he had gone eighty-four days " + "now without taking a fish." + ), + ) + writer.add_document(doc) + doc = Document.from_dict( + { + "id": 2, + "rating": 4.5, + "body": ( + "A few miles south of Soledad, the Salinas River drops " + "in close to the hillside bank and runs deep and " + "green. The water is warm too, for it has slipped " + "twinkling over the yellow sands in the sunlight " + "before reaching the narrow pool. On one side of the " + "river the golden foothill slopes curve up to the " + "strong and rocky Gabilan Mountains, but on the valley " + "side the water is lined with trees—willows fresh and " + "green with every spring, carrying in their lower leaf " + "junctures the debris of the winter’s flooding; and " + "sycamores with mottled, white, recumbent limbs and " + "branches that arch over the pool" + ), + } + ) + writer.add_document(doc) + writer.commit() + index.reload() + return index def spanish_schema(): return ( @@ -127,6 +175,11 @@ def ram_index(): return create_index() +@pytest.fixture(scope="class") +def ram_index_numeric_fields(): + return create_index_with_numeric_fields() + + @pytest.fixture(scope="class") def spanish_index(): return create_spanish_index() @@ -185,6 +238,25 @@ class TestClass(object): assert len(result.hits) == 1 + def test_and_query_numeric_fields(self, ram_index_numeric_fields): + index = ram_index_numeric_fields + searcher = index.searcher() + + # 1 result + float_query = index.parse_query("3.5", ["rating"]) + result = searcher.search(float_query) + assert len(result.hits) == 1 + assert searcher.doc(result.hits[0][1])['rating'][0] == 3.5 + + integer_query = index.parse_query("1", ["id"]) + result = searcher.search(integer_query) + assert len(result.hits) == 1 + + # 0 result + integer_query = index.parse_query("10", ["id"]) + result = searcher.search(integer_query) + assert len(result.hits) == 0 + def test_and_query_parser_default_fields(self, ram_index): query = ram_index.parse_query("winter", default_field_names=["title"]) assert repr(query) == """Query(TermQuery(Term(type=Str, field=0, "winter")))""" @@ -344,8 +416,9 @@ class TestFromDiskClass(object): class TestSearcher(object): - def test_searcher_repr(self, ram_index): + def test_searcher_repr(self, ram_index, ram_index_numeric_fields): assert repr(ram_index.searcher()) == "Searcher(num_docs=3, num_segments=1)" + assert repr(ram_index_numeric_fields.searcher()) == "Searcher(num_docs=2, num_segments=1)" class TestDocument(object):