diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index eb330dcc3663..8ffbd8bf608d 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -16,7 +16,7 @@ // under the License. use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, TableReference, @@ -139,6 +139,39 @@ pub(super) fn rename_field( .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) .with_name(name)) } + DataType::Map(inner, sorted) => match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let renamed_keys = rename_field( + key_and_value[0].as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + let renamed_values = rename_field( + key_and_value[1].as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::Map( + Arc::new(Field::new( + inner.name(), + DataType::Struct(Fields::from(vec![ + renamed_keys, + renamed_values, + ])), + inner.is_nullable(), + )), + *sorted, + )) + .with_name(name)) + } + _ => substrait_err!("Map fields must contain a Struct with exactly 2 fields"), + }, _ => Ok(field.to_owned().with_name(name)), } } @@ -381,11 +414,17 @@ pub async fn from_substrait_sorts( #[cfg(test)] pub(crate) mod tests { + use super::make_renamed_schema; use crate::extensions::Extensions; use crate::logical_plan::consumer::DefaultSubstraitConsumer; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::error::Result; use datafusion::execution::SessionState; use datafusion::prelude::SessionContext; - use 
/// Checks that `make_renamed_schema` applies a flat list of names to a
/// schema containing nested struct, list and map columns.
///
/// The input schema has four top-level columns whose nameable fields are
/// labelled "0" through "9"; the replacement names are "a" through "j".
/// The assertions expect that:
///
/// * names are handed out depth-first, in declaration order;
/// * the list element field ("item") keeps its name and uses up no name;
/// * the map's "entries", "keys" and "values" fields keep their names and
///   use up no name, while the struct fields nested inside the keys and
///   values are renamed.
///
/// This is a plain synchronous `#[test]`: nothing in the body awaits, so no
/// async runtime is needed to run it.
#[test]
fn rename_schema() -> Result<()> {
    /// Builds the four top-level columns of the test schema, using `names`
    /// for every field that is expected to be renamable.
    ///
    /// The same builder produces both the input schema and the expected
    /// output, so the two can only differ in the supplied names.
    fn nested_fields(names: [&str; 10]) -> Vec<Field> {
        let int32 = |name: &str| Field::new(name, DataType::Int32, false);
        let [scalar, outer, outer_leaf, inner, inner_leaf, list, list_leaf, map, key_leaf, value_leaf] =
            names;

        vec![
            // Plain scalar column.
            int32(scalar),
            // Struct column containing a leaf and a nested struct.
            Field::new_struct(
                outer,
                vec![
                    int32(outer_leaf),
                    Field::new_struct(inner, vec![int32(inner_leaf)], false),
                ],
                false,
            ),
            // List column whose element is a struct; the element field is
            // always called "item".
            Field::new_list(
                list,
                Arc::new(Field::new_struct("item", vec![int32(list_leaf)], false)),
                false,
            ),
            // Map column whose keys and values are both structs; the
            // "entries" / "keys" / "values" field names are fixed.
            Field::new_map(
                map,
                "entries",
                Arc::new(Field::new_struct("keys", vec![int32(key_leaf)], false)),
                Arc::new(Field::new_struct("values", vec![int32(value_leaf)], false)),
                false,
                false,
            ),
        ]
    }

    const ORIGINAL_NAMES: [&str; 10] = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"];
    const NEW_NAMES: [&str; 10] = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];

    // Every column is qualified with the same bare table reference.
    let table_ref = TableReference::bare("test");
    let fields = nested_fields(ORIGINAL_NAMES)
        .into_iter()
        .map(|field| (Some(table_ref.clone()), Arc::new(field)))
        .collect::<Vec<_>>();
    let schema = Arc::new(DFSchema::new_with_metadata(fields, HashMap::default())?);

    let dfs_names: Vec<String> = NEW_NAMES.iter().map(|name| name.to_string()).collect();
    let renamed_schema = make_renamed_schema(&schema, &dfs_names)?;

    let expected_fields = nested_fields(NEW_NAMES);
    assert_eq!(renamed_schema.fields().len(), expected_fields.len());
    for (idx, expected) in expected_fields.into_iter().enumerate() {
        assert_eq!(
            *renamed_schema.field(idx),
            expected,
            "renamed field at index {idx} does not match"
        );
    }
    Ok(())
}