diff --git a/src/document.rs b/src/document.rs index 529567a..7a534e0 100644 --- a/src/document.rs +++ b/src/document.rs @@ -22,6 +22,114 @@ use std::{ }; use tantivy::schema::Value; +pub(crate) fn extract_value(any: &PyAny) -> PyResult { + if let Ok(s) = any.extract::() { + return Ok(Value::Str(s)); + } + if let Ok(num) = any.extract::() { + return Ok(Value::I64(num)); + } + if let Ok(num) = any.extract::() { + return Ok(Value::F64(num)); + } + if let Ok(datetime) = any.extract::() { + return Ok(Value::Date(tv::DateTime::from_timestamp_secs( + datetime.timestamp(), + ))); + } + if let Ok(facet) = any.extract::() { + return Ok(Value::Facet(facet.inner)); + } + if let Ok(b) = any.extract::>() { + return Ok(Value::Bytes(b)); + } + Err(to_pyerr(format!("Value unsupported {any:?}"))) +} + +pub(crate) fn extract_value_for_type( + any: &PyAny, + tv_type: tv::schema::Type, + field_name: &str, +) -> PyResult { + // Helper function to create `PyErr`s returned by this function. + fn to_pyerr_for_type<'a, E: std::error::Error>( + type_name: &'a str, + field_name: &'a str, + any: &'a PyAny, + ) -> impl Fn(E) -> PyErr + 'a { + move |_| { + to_pyerr(format!( + "Expected {} type for field {}, got {:?}", + type_name, field_name, any + )) + } + } + + let value = match tv_type { + tv::schema::Type::Str => Value::Str( + any.extract::() + .map_err(to_pyerr_for_type("Str", field_name, any))?, + ), + tv::schema::Type::U64 => Value::U64( + any.extract::() + .map_err(to_pyerr_for_type("U64", field_name, any))?, + ), + tv::schema::Type::I64 => Value::I64( + any.extract::() + .map_err(to_pyerr_for_type("I64", field_name, any))?, + ), + tv::schema::Type::F64 => Value::F64( + any.extract::() + .map_err(to_pyerr_for_type("F64", field_name, any))?, + ), + tv::schema::Type::Date => { + let datetime = any + .extract::() + .map_err(to_pyerr_for_type("DateTime", field_name, any))?; + + Value::Date(tv::DateTime::from_timestamp_secs(datetime.timestamp())) + } + tv::schema::Type::Facet => Value::Facet( + any.extract::() + .map_err(to_pyerr_for_type("Facet", field_name, any))? + .inner, + ), + _ => return Err(to_pyerr(format!("Value unsupported {:?}", any))), + }; + + Ok(value) +} + +fn extract_value_single_or_list(any: &PyAny) -> PyResult> { + if let Ok(values) = any.downcast::() { + values.iter().map(extract_value).collect() + } else { + Ok(vec![extract_value(any)?]) + } +} + +fn extract_value_single_or_list_for_type( + any: &PyAny, + field_type: &tv::schema::FieldType, + field_name: &str, +) -> PyResult> { + // Check if a numeric fast field supports multivalues. + if let Ok(values) = any.downcast::() { + values + .iter() + .map(|any| { + extract_value_for_type(any, field_type.value_type(), field_name) + }) + .collect() + } else { + Ok(vec![extract_value_for_type( + any, + field_type.value_type(), + field_name, + )?]) + } +} + fn value_to_object(val: &JsonValue, py: Python<'_>) -> PyObject { match val { JsonValue::Null => py.None(), @@ -174,179 +282,6 @@ impl fmt::Debug for Document { } } -fn add_value(doc: &mut Document, field_name: String, value: T) -where - Value: From, -{ - doc.field_values - .entry(field_name) - .or_insert_with(Vec::new) - .push(Value::from(value)); -} - -pub(crate) fn extract_value(any: &PyAny) -> PyResult { - if let Ok(s) = any.extract::() { - return Ok(Value::Str(s)); - } - if let Ok(num) = any.extract::() { - return Ok(Value::I64(num)); - } - if let Ok(num) = any.extract::() { - return Ok(Value::F64(num)); - } - if let Ok(datetime) = any.extract::() { - return Ok(Value::Date(tv::DateTime::from_timestamp_secs( - datetime.timestamp(), - ))); - } - if let Ok(facet) = any.extract::() { - return Ok(Value::Facet(facet.inner)); - } - if let Ok(b) = any.extract::>() { - return Ok(Value::Bytes(b)); - } - Err(to_pyerr(format!("Value unsupported {any:?}"))) -} - -pub(crate) fn extract_value_for_type( - any: &PyAny, - tv_type: tv::schema::Type, - field_name: &str, -) -> PyResult { - // Helper function to create `PyErr`s returned by this function. - fn to_pyerr_for_type<'a, E: std::error::Error>( - type_name: &'a str, - field_name: &'a str, - any: &'a PyAny, - ) -> impl Fn(E) -> PyErr + 'a { - move |_| { - to_pyerr(format!( - "Expected {} type for field {}, got {:?}", - type_name, field_name, any - )) - } - } - - let value = match tv_type { - tv::schema::Type::Str => Value::Str( - any.extract::() - .map_err(to_pyerr_for_type("Str", field_name, any))?, - ), - tv::schema::Type::U64 => Value::U64( - any.extract::() - .map_err(to_pyerr_for_type("U64", field_name, any))?, - ), - tv::schema::Type::I64 => Value::I64( - any.extract::() - .map_err(to_pyerr_for_type("I64", field_name, any))?, - ), - tv::schema::Type::F64 => Value::F64( - any.extract::() - .map_err(to_pyerr_for_type("F64", field_name, any))?, - ), - tv::schema::Type::Date => { - let datetime = any - .extract::() - .map_err(to_pyerr_for_type("DateTime", field_name, any))?; - - Value::Date(tv::DateTime::from_timestamp_secs(datetime.timestamp())) - } - tv::schema::Type::Facet => Value::Facet( - any.extract::() - .map_err(to_pyerr_for_type("Facet", field_name, any))? - .inner, - ), - _ => return Err(to_pyerr(format!("Value unsupported {:?}", any))), - }; - - Ok(value) -} - -fn extract_value_single_or_list(any: &PyAny) -> PyResult> { - if let Ok(values) = any.downcast::() { - values.iter().map(extract_value).collect() - } else { - Ok(vec![extract_value(any)?]) - } -} - -fn extract_value_single_or_list_for_type( - any: &PyAny, - field_type: &tv::schema::FieldType, - field_name: &str, -) -> PyResult> { - // Check if a numeric fast field supports multivalues. - if let Ok(values) = any.downcast::() { - values - .iter() - .map(|any| { - extract_value_for_type(any, field_type.value_type(), field_name) - }) - .collect() - } else { - Ok(vec![extract_value_for_type( - any, - field_type.value_type(), - field_name, - )?]) - } -} - -impl Document { - fn extract_py_values_from_dict( - py_dict: &PyDict, - schema: Option<&Schema>, - out_field_values: &mut BTreeMap>, - ) -> PyResult<()> { - // TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable. - // out_field_values.reserve(py_dict.len()); - - for key_value_any in py_dict.items() { - if let Ok(key_value) = key_value_any.downcast::() { - if key_value.len() != 2 { - continue; - } - let key = key_value.get_item(0)?.extract::()?; - - let field_type = if let Some(schema) = schema { - let field_type = schema - .inner - .get_field(key.as_str()) - .map(|field| { - schema.inner.get_field_entry(field).field_type() - }) - .ok(); - - if let Some(field_type) = field_type { - // A field type was found, so validate it after the values are extracted. - Some(field_type) - } else { - // The field does not exist in the schema, so skip over it. - continue; - } - } else { - // No schema was provided, so do not validate anything. - None - }; - - let value_list = if let Some(field_type) = field_type { - extract_value_single_or_list_for_type( - key_value.get_item(1)?, - field_type, - key.as_str(), - )? - } else { - extract_value_single_or_list(key_value.get_item(1)?)? - }; - - out_field_values.insert(key, value_list); - } - } - - Ok(()) - } -} - #[pymethods] impl Document { /// Creates a new document with optional fields from `**kwargs`. @@ -417,7 +352,7 @@ impl Document { /// field_name (str): The field name for which we are adding the text. /// text (str): The text that will be added to the document. fn add_text(&mut self, field_name: String, text: &str) { - add_value(self, field_name, text); + self.add_value(field_name, text); } /// Add an unsigned integer value to the document. @@ -426,7 +361,7 @@ impl Document { /// field_name (str): The field name for which we are adding the unsigned integer. /// value (int): The integer that will be added to the document. fn add_unsigned(&mut self, field_name: String, value: u64) { - add_value(self, field_name, value); + self.add_value(field_name, value); } /// Add a signed integer value to the document. @@ -435,7 +370,7 @@ impl Document { /// field_name (str): The field name for which we are adding the integer. /// value (int): The integer that will be added to the document. fn add_integer(&mut self, field_name: String, value: i64) { - add_value(self, field_name, value); + self.add_value(field_name, value); } /// Add a float value to the document. @@ -444,7 +379,7 @@ impl Document { /// field_name (str): The field name for which we are adding the value. /// value (f64): The float that will be added to the document. fn add_float(&mut self, field_name: String, value: f64) { - add_value(self, field_name, value); + self.add_value(field_name, value); } /// Add a date value to the document. @@ -464,8 +399,7 @@ impl Document { ) .single() .unwrap(); - add_value( - self, + self.add_value( field_name, tv::DateTime::from_timestamp_secs(datetime.timestamp()), ); @@ -476,7 +410,7 @@ impl Document { /// field_name (str): The field name for which we are adding the facet. /// value (Facet): The Facet that will be added to the document. fn add_facet(&mut self, field_name: String, facet: &Facet) { - add_value(self, field_name, facet.inner.clone()); + self.add_value(field_name, facet.inner.clone()); } /// Add a bytes value to the document. @@ -485,7 +419,7 @@ impl Document { /// field_name (str): The field for which we are adding the bytes. /// value (bytes): The bytes that will be added to the document. fn add_bytes(&mut self, field_name: String, bytes: Vec) { - add_value(self, field_name, bytes); + self.add_value(field_name, bytes); } /// Add a bytes value to the document. @@ -496,7 +430,7 @@ impl Document { fn add_json(&mut self, field_name: String, json: &str) { let json_object: serde_json::Value = serde_json::from_str(json).unwrap(); - add_value(self, field_name, json_object); + self.add_value(field_name, json_object); } /// Returns the number of added fields that have been added to the document @@ -577,6 +511,69 @@ impl Document { } impl Document { + fn add_value(&mut self, field_name: String, value: T) + where + Value: From, + { + self.field_values + .entry(field_name) + .or_insert_with(Vec::new) + .push(Value::from(value)); + } + + fn extract_py_values_from_dict( + py_dict: &PyDict, + schema: Option<&Schema>, + out_field_values: &mut BTreeMap>, + ) -> PyResult<()> { + // TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable. + // out_field_values.reserve(py_dict.len()); + + for key_value_any in py_dict.items() { + if let Ok(key_value) = key_value_any.downcast::() { + if key_value.len() != 2 { + continue; + } + let key = key_value.get_item(0)?.extract::()?; + + let field_type = if let Some(schema) = schema { + let field_type = schema + .inner + .get_field(key.as_str()) + .map(|field| { + schema.inner.get_field_entry(field).field_type() + }) + .ok(); + + if let Some(field_type) = field_type { + // A field type was found, so validate it after the values are extracted. + Some(field_type) + } else { + // The field does not exist in the schema, so skip over it. + continue; + } + } else { + // No schema was provided, so do not validate anything. + None + }; + + let value_list = if let Some(field_type) = field_type { + extract_value_single_or_list_for_type( + key_value.get_item(1)?, + field_type, + key.as_str(), + )? + } else { + extract_value_single_or_list(key_value.get_item(1)?)? + }; + + out_field_values.insert(key, value_list); + } + } + + Ok(()) + } + fn iter_values_for_field<'a>( &'a self, field: &str,