Clean up document.rs (#101)
parent
a01ccd99cb
commit
50809a186d
361
src/document.rs
361
src/document.rs
|
@ -22,6 +22,114 @@ use std::{
|
||||||
};
|
};
|
||||||
use tantivy::schema::Value;
|
use tantivy::schema::Value;
|
||||||
|
|
||||||
|
pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
|
||||||
|
if let Ok(s) = any.extract::<String>() {
|
||||||
|
return Ok(Value::Str(s));
|
||||||
|
}
|
||||||
|
if let Ok(num) = any.extract::<i64>() {
|
||||||
|
return Ok(Value::I64(num));
|
||||||
|
}
|
||||||
|
if let Ok(num) = any.extract::<f64>() {
|
||||||
|
return Ok(Value::F64(num));
|
||||||
|
}
|
||||||
|
if let Ok(datetime) = any.extract::<NaiveDateTime>() {
|
||||||
|
return Ok(Value::Date(tv::DateTime::from_timestamp_secs(
|
||||||
|
datetime.timestamp(),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if let Ok(facet) = any.extract::<Facet>() {
|
||||||
|
return Ok(Value::Facet(facet.inner));
|
||||||
|
}
|
||||||
|
if let Ok(b) = any.extract::<Vec<u8>>() {
|
||||||
|
return Ok(Value::Bytes(b));
|
||||||
|
}
|
||||||
|
Err(to_pyerr(format!("Value unsupported {any:?}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn extract_value_for_type(
|
||||||
|
any: &PyAny,
|
||||||
|
tv_type: tv::schema::Type,
|
||||||
|
field_name: &str,
|
||||||
|
) -> PyResult<Value> {
|
||||||
|
// Helper function to create `PyErr`s returned by this function.
|
||||||
|
fn to_pyerr_for_type<'a, E: std::error::Error>(
|
||||||
|
type_name: &'a str,
|
||||||
|
field_name: &'a str,
|
||||||
|
any: &'a PyAny,
|
||||||
|
) -> impl Fn(E) -> PyErr + 'a {
|
||||||
|
move |_| {
|
||||||
|
to_pyerr(format!(
|
||||||
|
"Expected {} type for field {}, got {:?}",
|
||||||
|
type_name, field_name, any
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = match tv_type {
|
||||||
|
tv::schema::Type::Str => Value::Str(
|
||||||
|
any.extract::<String>()
|
||||||
|
.map_err(to_pyerr_for_type("Str", field_name, any))?,
|
||||||
|
),
|
||||||
|
tv::schema::Type::U64 => Value::U64(
|
||||||
|
any.extract::<u64>()
|
||||||
|
.map_err(to_pyerr_for_type("U64", field_name, any))?,
|
||||||
|
),
|
||||||
|
tv::schema::Type::I64 => Value::I64(
|
||||||
|
any.extract::<i64>()
|
||||||
|
.map_err(to_pyerr_for_type("I64", field_name, any))?,
|
||||||
|
),
|
||||||
|
tv::schema::Type::F64 => Value::F64(
|
||||||
|
any.extract::<f64>()
|
||||||
|
.map_err(to_pyerr_for_type("F64", field_name, any))?,
|
||||||
|
),
|
||||||
|
tv::schema::Type::Date => {
|
||||||
|
let datetime = any
|
||||||
|
.extract::<NaiveDateTime>()
|
||||||
|
.map_err(to_pyerr_for_type("DateTime", field_name, any))?;
|
||||||
|
|
||||||
|
Value::Date(tv::DateTime::from_timestamp_secs(datetime.timestamp()))
|
||||||
|
}
|
||||||
|
tv::schema::Type::Facet => Value::Facet(
|
||||||
|
any.extract::<Facet>()
|
||||||
|
.map_err(to_pyerr_for_type("Facet", field_name, any))?
|
||||||
|
.inner,
|
||||||
|
),
|
||||||
|
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
|
||||||
|
if let Ok(values) = any.downcast::<PyList>() {
|
||||||
|
values.iter().map(extract_value).collect()
|
||||||
|
} else {
|
||||||
|
Ok(vec![extract_value(any)?])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_value_single_or_list_for_type(
|
||||||
|
any: &PyAny,
|
||||||
|
field_type: &tv::schema::FieldType,
|
||||||
|
field_name: &str,
|
||||||
|
) -> PyResult<Vec<Value>> {
|
||||||
|
// Check if a numeric fast field supports multivalues.
|
||||||
|
if let Ok(values) = any.downcast::<PyList>() {
|
||||||
|
values
|
||||||
|
.iter()
|
||||||
|
.map(|any| {
|
||||||
|
extract_value_for_type(any, field_type.value_type(), field_name)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
Ok(vec![extract_value_for_type(
|
||||||
|
any,
|
||||||
|
field_type.value_type(),
|
||||||
|
field_name,
|
||||||
|
)?])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn value_to_object(val: &JsonValue, py: Python<'_>) -> PyObject {
|
fn value_to_object(val: &JsonValue, py: Python<'_>) -> PyObject {
|
||||||
match val {
|
match val {
|
||||||
JsonValue::Null => py.None(),
|
JsonValue::Null => py.None(),
|
||||||
|
@ -174,179 +282,6 @@ impl fmt::Debug for Document {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_value<T>(doc: &mut Document, field_name: String, value: T)
|
|
||||||
where
|
|
||||||
Value: From<T>,
|
|
||||||
{
|
|
||||||
doc.field_values
|
|
||||||
.entry(field_name)
|
|
||||||
.or_insert_with(Vec::new)
|
|
||||||
.push(Value::from(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
|
|
||||||
if let Ok(s) = any.extract::<String>() {
|
|
||||||
return Ok(Value::Str(s));
|
|
||||||
}
|
|
||||||
if let Ok(num) = any.extract::<i64>() {
|
|
||||||
return Ok(Value::I64(num));
|
|
||||||
}
|
|
||||||
if let Ok(num) = any.extract::<f64>() {
|
|
||||||
return Ok(Value::F64(num));
|
|
||||||
}
|
|
||||||
if let Ok(datetime) = any.extract::<NaiveDateTime>() {
|
|
||||||
return Ok(Value::Date(tv::DateTime::from_timestamp_secs(
|
|
||||||
datetime.timestamp(),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
if let Ok(facet) = any.extract::<Facet>() {
|
|
||||||
return Ok(Value::Facet(facet.inner));
|
|
||||||
}
|
|
||||||
if let Ok(b) = any.extract::<Vec<u8>>() {
|
|
||||||
return Ok(Value::Bytes(b));
|
|
||||||
}
|
|
||||||
Err(to_pyerr(format!("Value unsupported {any:?}")))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn extract_value_for_type(
|
|
||||||
any: &PyAny,
|
|
||||||
tv_type: tv::schema::Type,
|
|
||||||
field_name: &str,
|
|
||||||
) -> PyResult<Value> {
|
|
||||||
// Helper function to create `PyErr`s returned by this function.
|
|
||||||
fn to_pyerr_for_type<'a, E: std::error::Error>(
|
|
||||||
type_name: &'a str,
|
|
||||||
field_name: &'a str,
|
|
||||||
any: &'a PyAny,
|
|
||||||
) -> impl Fn(E) -> PyErr + 'a {
|
|
||||||
move |_| {
|
|
||||||
to_pyerr(format!(
|
|
||||||
"Expected {} type for field {}, got {:?}",
|
|
||||||
type_name, field_name, any
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let value = match tv_type {
|
|
||||||
tv::schema::Type::Str => Value::Str(
|
|
||||||
any.extract::<String>()
|
|
||||||
.map_err(to_pyerr_for_type("Str", field_name, any))?,
|
|
||||||
),
|
|
||||||
tv::schema::Type::U64 => Value::U64(
|
|
||||||
any.extract::<u64>()
|
|
||||||
.map_err(to_pyerr_for_type("U64", field_name, any))?,
|
|
||||||
),
|
|
||||||
tv::schema::Type::I64 => Value::I64(
|
|
||||||
any.extract::<i64>()
|
|
||||||
.map_err(to_pyerr_for_type("I64", field_name, any))?,
|
|
||||||
),
|
|
||||||
tv::schema::Type::F64 => Value::F64(
|
|
||||||
any.extract::<f64>()
|
|
||||||
.map_err(to_pyerr_for_type("F64", field_name, any))?,
|
|
||||||
),
|
|
||||||
tv::schema::Type::Date => {
|
|
||||||
let datetime = any
|
|
||||||
.extract::<NaiveDateTime>()
|
|
||||||
.map_err(to_pyerr_for_type("DateTime", field_name, any))?;
|
|
||||||
|
|
||||||
Value::Date(tv::DateTime::from_timestamp_secs(datetime.timestamp()))
|
|
||||||
}
|
|
||||||
tv::schema::Type::Facet => Value::Facet(
|
|
||||||
any.extract::<Facet>()
|
|
||||||
.map_err(to_pyerr_for_type("Facet", field_name, any))?
|
|
||||||
.inner,
|
|
||||||
),
|
|
||||||
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_value_single_or_list(any: &PyAny) -> PyResult<Vec<Value>> {
|
|
||||||
if let Ok(values) = any.downcast::<PyList>() {
|
|
||||||
values.iter().map(extract_value).collect()
|
|
||||||
} else {
|
|
||||||
Ok(vec![extract_value(any)?])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_value_single_or_list_for_type(
|
|
||||||
any: &PyAny,
|
|
||||||
field_type: &tv::schema::FieldType,
|
|
||||||
field_name: &str,
|
|
||||||
) -> PyResult<Vec<Value>> {
|
|
||||||
// Check if a numeric fast field supports multivalues.
|
|
||||||
if let Ok(values) = any.downcast::<PyList>() {
|
|
||||||
values
|
|
||||||
.iter()
|
|
||||||
.map(|any| {
|
|
||||||
extract_value_for_type(any, field_type.value_type(), field_name)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
Ok(vec![extract_value_for_type(
|
|
||||||
any,
|
|
||||||
field_type.value_type(),
|
|
||||||
field_name,
|
|
||||||
)?])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Document {
|
|
||||||
fn extract_py_values_from_dict(
|
|
||||||
py_dict: &PyDict,
|
|
||||||
schema: Option<&Schema>,
|
|
||||||
out_field_values: &mut BTreeMap<String, Vec<tv::schema::Value>>,
|
|
||||||
) -> PyResult<()> {
|
|
||||||
// TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
|
|
||||||
// out_field_values.reserve(py_dict.len());
|
|
||||||
|
|
||||||
for key_value_any in py_dict.items() {
|
|
||||||
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
|
|
||||||
if key_value.len() != 2 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let key = key_value.get_item(0)?.extract::<String>()?;
|
|
||||||
|
|
||||||
let field_type = if let Some(schema) = schema {
|
|
||||||
let field_type = schema
|
|
||||||
.inner
|
|
||||||
.get_field(key.as_str())
|
|
||||||
.map(|field| {
|
|
||||||
schema.inner.get_field_entry(field).field_type()
|
|
||||||
})
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
if let Some(field_type) = field_type {
|
|
||||||
// A field type was found, so validate it after the values are extracted.
|
|
||||||
Some(field_type)
|
|
||||||
} else {
|
|
||||||
// The field does not exist in the schema, so skip over it.
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No schema was provided, so do not validate anything.
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let value_list = if let Some(field_type) = field_type {
|
|
||||||
extract_value_single_or_list_for_type(
|
|
||||||
key_value.get_item(1)?,
|
|
||||||
field_type,
|
|
||||||
key.as_str(),
|
|
||||||
)?
|
|
||||||
} else {
|
|
||||||
extract_value_single_or_list(key_value.get_item(1)?)?
|
|
||||||
};
|
|
||||||
|
|
||||||
out_field_values.insert(key, value_list);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl Document {
|
impl Document {
|
||||||
/// Creates a new document with optional fields from `**kwargs`.
|
/// Creates a new document with optional fields from `**kwargs`.
|
||||||
|
@ -417,7 +352,7 @@ impl Document {
|
||||||
/// field_name (str): The field name for which we are adding the text.
|
/// field_name (str): The field name for which we are adding the text.
|
||||||
/// text (str): The text that will be added to the document.
|
/// text (str): The text that will be added to the document.
|
||||||
fn add_text(&mut self, field_name: String, text: &str) {
|
fn add_text(&mut self, field_name: String, text: &str) {
|
||||||
add_value(self, field_name, text);
|
self.add_value(field_name, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an unsigned integer value to the document.
|
/// Add an unsigned integer value to the document.
|
||||||
|
@ -426,7 +361,7 @@ impl Document {
|
||||||
/// field_name (str): The field name for which we are adding the unsigned integer.
|
/// field_name (str): The field name for which we are adding the unsigned integer.
|
||||||
/// value (int): The integer that will be added to the document.
|
/// value (int): The integer that will be added to the document.
|
||||||
fn add_unsigned(&mut self, field_name: String, value: u64) {
|
fn add_unsigned(&mut self, field_name: String, value: u64) {
|
||||||
add_value(self, field_name, value);
|
self.add_value(field_name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a signed integer value to the document.
|
/// Add a signed integer value to the document.
|
||||||
|
@ -435,7 +370,7 @@ impl Document {
|
||||||
/// field_name (str): The field name for which we are adding the integer.
|
/// field_name (str): The field name for which we are adding the integer.
|
||||||
/// value (int): The integer that will be added to the document.
|
/// value (int): The integer that will be added to the document.
|
||||||
fn add_integer(&mut self, field_name: String, value: i64) {
|
fn add_integer(&mut self, field_name: String, value: i64) {
|
||||||
add_value(self, field_name, value);
|
self.add_value(field_name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a float value to the document.
|
/// Add a float value to the document.
|
||||||
|
@ -444,7 +379,7 @@ impl Document {
|
||||||
/// field_name (str): The field name for which we are adding the value.
|
/// field_name (str): The field name for which we are adding the value.
|
||||||
/// value (f64): The float that will be added to the document.
|
/// value (f64): The float that will be added to the document.
|
||||||
fn add_float(&mut self, field_name: String, value: f64) {
|
fn add_float(&mut self, field_name: String, value: f64) {
|
||||||
add_value(self, field_name, value);
|
self.add_value(field_name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a date value to the document.
|
/// Add a date value to the document.
|
||||||
|
@ -464,8 +399,7 @@ impl Document {
|
||||||
)
|
)
|
||||||
.single()
|
.single()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
add_value(
|
self.add_value(
|
||||||
self,
|
|
||||||
field_name,
|
field_name,
|
||||||
tv::DateTime::from_timestamp_secs(datetime.timestamp()),
|
tv::DateTime::from_timestamp_secs(datetime.timestamp()),
|
||||||
);
|
);
|
||||||
|
@ -476,7 +410,7 @@ impl Document {
|
||||||
/// field_name (str): The field name for which we are adding the facet.
|
/// field_name (str): The field name for which we are adding the facet.
|
||||||
/// value (Facet): The Facet that will be added to the document.
|
/// value (Facet): The Facet that will be added to the document.
|
||||||
fn add_facet(&mut self, field_name: String, facet: &Facet) {
|
fn add_facet(&mut self, field_name: String, facet: &Facet) {
|
||||||
add_value(self, field_name, facet.inner.clone());
|
self.add_value(field_name, facet.inner.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a bytes value to the document.
|
/// Add a bytes value to the document.
|
||||||
|
@ -485,7 +419,7 @@ impl Document {
|
||||||
/// field_name (str): The field for which we are adding the bytes.
|
/// field_name (str): The field for which we are adding the bytes.
|
||||||
/// value (bytes): The bytes that will be added to the document.
|
/// value (bytes): The bytes that will be added to the document.
|
||||||
fn add_bytes(&mut self, field_name: String, bytes: Vec<u8>) {
|
fn add_bytes(&mut self, field_name: String, bytes: Vec<u8>) {
|
||||||
add_value(self, field_name, bytes);
|
self.add_value(field_name, bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a JSON value to the document.
|
/// Add a JSON value to the document.
|
||||||
|
@ -496,7 +430,7 @@ impl Document {
|
||||||
fn add_json(&mut self, field_name: String, json: &str) {
|
fn add_json(&mut self, field_name: String, json: &str) {
|
||||||
let json_object: serde_json::Value =
|
let json_object: serde_json::Value =
|
||||||
serde_json::from_str(json).unwrap();
|
serde_json::from_str(json).unwrap();
|
||||||
add_value(self, field_name, json_object);
|
self.add_value(field_name, json_object);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of fields that have been added to the document
|
/// Returns the number of fields that have been added to the document
|
||||||
|
@ -577,6 +511,69 @@ impl Document {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Document {
|
impl Document {
|
||||||
|
fn add_value<T>(&mut self, field_name: String, value: T)
|
||||||
|
where
|
||||||
|
Value: From<T>,
|
||||||
|
{
|
||||||
|
self.field_values
|
||||||
|
.entry(field_name)
|
||||||
|
.or_insert_with(Vec::new)
|
||||||
|
.push(Value::from(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_py_values_from_dict(
|
||||||
|
py_dict: &PyDict,
|
||||||
|
schema: Option<&Schema>,
|
||||||
|
out_field_values: &mut BTreeMap<String, Vec<tv::schema::Value>>,
|
||||||
|
) -> PyResult<()> {
|
||||||
|
// TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
|
||||||
|
// out_field_values.reserve(py_dict.len());
|
||||||
|
|
||||||
|
for key_value_any in py_dict.items() {
|
||||||
|
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
|
||||||
|
if key_value.len() != 2 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let key = key_value.get_item(0)?.extract::<String>()?;
|
||||||
|
|
||||||
|
let field_type = if let Some(schema) = schema {
|
||||||
|
let field_type = schema
|
||||||
|
.inner
|
||||||
|
.get_field(key.as_str())
|
||||||
|
.map(|field| {
|
||||||
|
schema.inner.get_field_entry(field).field_type()
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
if let Some(field_type) = field_type {
|
||||||
|
// A field type was found, so validate it after the values are extracted.
|
||||||
|
Some(field_type)
|
||||||
|
} else {
|
||||||
|
// The field does not exist in the schema, so skip over it.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No schema was provided, so do not validate anything.
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let value_list = if let Some(field_type) = field_type {
|
||||||
|
extract_value_single_or_list_for_type(
|
||||||
|
key_value.get_item(1)?,
|
||||||
|
field_type,
|
||||||
|
key.as_str(),
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
extract_value_single_or_list(key_value.get_item(1)?)?
|
||||||
|
};
|
||||||
|
|
||||||
|
out_field_values.insert(key, value_list);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn iter_values_for_field<'a>(
|
fn iter_values_for_field<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
field: &str,
|
field: &str,
|
||||||
|
|
Loading…
Reference in New Issue