diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 91a871d52e9a..2f6eaf1c052f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1468,19 +1468,37 @@ impl ValuesFields { } } +// `name_map` tracks a mapping between a field name and the number of appearances of that field. +// +// Some field names might already come to this function with the count (number of times it appeared) +// as a sufix e.g. id:1, so there's still a chance of name collisions, for example, +// if these three fields passed to this function: "col:1", "col" and "col", the function +// would rename them to -> col:1, col, col:1 causing a posteriror error when building the DFSchema. +// that's why we need the `seen` set, so the fields are always unique. +// pub fn change_redundant_column(fields: &Fields) -> Vec { let mut name_map = HashMap::new(); + let mut seen: HashSet = HashSet::new(); + fields .into_iter() .map(|field| { - let counter = name_map.entry(field.name().to_string()).or_insert(0); - *counter += 1; - if *counter > 1 { - let new_name = format!("{}:{}", field.name(), *counter - 1); - Field::new(new_name, field.data_type().clone(), field.is_nullable()) - } else { - field.as_ref().clone() + let base_name = field.name(); + let count = name_map.entry(base_name.clone()).or_insert(0); + let mut new_name = base_name.clone(); + + // Loop until we find a name that hasn't been used + while seen.contains(&new_name) { + *count += 1; + new_name = format!("{}:{}", base_name, count); } + + seen.insert(new_name.clone()); + + let mut modified_field = + Field::new(&new_name, field.data_type().clone(), field.is_nullable()); + modified_field.set_metadata(field.metadata().clone()); + modified_field }) .collect() } @@ -2730,10 +2748,13 @@ mod tests { let t1_field_1 = Field::new("a", DataType::Int32, false); let t2_field_1 = Field::new("a", DataType::Int32, false); let t2_field_3 = Field::new("a", DataType::Int32, false); + let t2_field_4 = Field::new("a:1", DataType::Int32, false); let t1_field_2 = Field::new("b", DataType::Int32, false); let t2_field_2 = Field::new("b", DataType::Int32, false); - let field_vec = vec![t1_field_1, t2_field_1, t1_field_2, t2_field_2, t2_field_3]; + let field_vec = vec![ + t1_field_1, t2_field_1, t1_field_2, t2_field_2, t2_field_3, t2_field_4, + ]; let remove_redundant = change_redundant_column(&Fields::from(field_vec)); assert_eq!( @@ -2744,6 +2765,7 @@ mod tests { Field::new("b", DataType::Int32, false), Field::new("b:1", DataType::Int32, false), Field::new("a:2", DataType::Int32, false), + Field::new("a:1:1", DataType::Int32, false), ] ); Ok(()) diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 035678fbf1f3..d10243fbab45 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -22,7 +22,7 @@ use crate::PhysicalExpr; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; /// Stores the mapping between source expressions and target expressions for a /// projection. @@ -66,9 +66,9 @@ impl ProjectionMapping { let idx = col.index(); let matching_input_field = input_schema.field(idx); if col.name() != matching_input_field.name() { - return internal_err!("Input field name {} does not match with the projection expression {}", - matching_input_field.name(),col.name()) - } + let fixed_col = Column::new(col.name(), idx); + return Ok(Transformed::yes(Arc::new(fixed_col))); + } let matching_input_column = Column::new(matching_input_field.name(), idx); Ok(Transformed::yes(Arc::new(matching_input_column))) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 61f3379735c7..1442267d3dbb 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1835,8 +1835,7 @@ fn requalify_sides_if_needed( }) }) { // These names have no connection to the original plan, but they'll make the columns - // (mostly) unique. There may be cases where this still causes duplicates, if either left - // or right side itself contains duplicate names with different qualifiers. + // (mostly) unique. Ok(( left.alias(TableReference::bare("left"))?, right.alias(TableReference::bare("right"))?, diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index af9d92378298..bdeeeb585c0c 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -519,6 +519,33 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_multiple_joins() -> Result<()> { + let plan_str = test_plan_to_string("multiple_joins.json").await?; + assert_eq!( + plan_str, + "Projection: left.count(Int64(1)) AS count_first, left.category, left.count(Int64(1)):1 AS count_second, right.count(Int64(1)) AS count_third\ + \n Left Join: left.id = right.id\ + \n SubqueryAlias: left\ + \n Left Join: left.id = right.id\ + \n SubqueryAlias: left\ + \n Left Join: left.id = right.id\ + \n SubqueryAlias: left\ + \n Aggregate: groupBy=[[id]], aggr=[[count(Int64(1))]]\ + \n Values: (Int64(1)), (Int64(2))\ + \n SubqueryAlias: right\ + \n Aggregate: groupBy=[[id, category]], aggr=[[]]\ + \n Values: (Int64(1), Utf8(\"info\")), (Int64(2), Utf8(\"low\"))\ + \n SubqueryAlias: right\ + \n Aggregate: groupBy=[[id]], aggr=[[count(Int64(1))]]\ + \n Values: (Int64(1)), (Int64(2))\ + \n SubqueryAlias: right\ + \n Aggregate: groupBy=[[id]], aggr=[[count(Int64(1))]]\ + \n Values: (Int64(1)), (Int64(2))" + ); + Ok(()) + } + #[tokio::test] async fn test_select_window_count() -> Result<()> { let plan_str = test_plan_to_string("select_window_count.substrait.json").await?; diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_joins.json b/datafusion/substrait/tests/testdata/test_plans/multiple_joins.json new file mode 100644 index 000000000000..e88cce648da7 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_joins.json @@ -0,0 +1,536 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "count:" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [8, 9, 10, 11] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id", "category"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }, { + "string": "info", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }, { + "string": "low", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [], + "groupingExpressions": [] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["count_first", "category", "count_second", "count_third"] + } + }], + "expectedTypeUrls": [], + "version": { + "majorNumber": 0, + "minorNumber": 52, + "patchNumber": 0, + "gitHash": "" + } +} \ No newline at end of file