Skip to content

Commit 99bc393

Browse files
authored
Move Scalar Subquery validation logic to the Analyzer (#6084)
* Move Scalar Subquery validation logic to the Analyzer
* resolve review comments
* resolve review comments
1 parent 569f6fe commit 99bc393

File tree

13 files changed

+1031
-472
lines changed

13 files changed

+1031
-472
lines changed

benchmarks/expected-plans/q2.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
44
| logical_plan | Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST |
55
| | Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment |
6-
| | Inner Join: part.p_partkey = __scalar_sq_1.ps_partkey, partsupp.ps_supplycost = __scalar_sq_1.__value |
6+
| | Inner Join: partsupp.ps_supplycost = __scalar_sq_1.__value, part.p_partkey = __scalar_sq_1.ps_partkey |
77
| | Projection: part.p_partkey, part.p_mfgr, partsupp.ps_supplycost, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name |
88
| | Inner Join: nation.n_regionkey = region.r_regionkey |
99
| | Projection: part.p_partkey, part.p_mfgr, partsupp.ps_supplycost, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name, nation.n_regionkey |
@@ -40,9 +40,9 @@
4040
| | SortExec: expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] |
4141
| | ProjectionExec: expr=[s_acctbal@6 as s_acctbal, s_name@3 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@4 as s_address, s_phone@5 as s_phone, s_comment@7 as s_comment] |
4242
| | CoalesceBatchesExec: target_batch_size=8192 |
43-
| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 0 }, Column { name: "ps_partkey", index: 0 }), (Column { name: "ps_supplycost", index: 2 }, Column { name: "__value", index: 1 })] |
43+
| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_supplycost", index: 2 }, Column { name: "__value", index: 1 }), (Column { name: "p_partkey", index: 0 }, Column { name: "ps_partkey", index: 0 })] |
4444
| | CoalesceBatchesExec: target_batch_size=8192 |
45-
| | RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }, Column { name: "ps_supplycost", index: 2 }], 2), input_partitions=2 |
45+
| | RepartitionExec: partitioning=Hash([Column { name: "ps_supplycost", index: 2 }, Column { name: "p_partkey", index: 0 }], 2), input_partitions=2 |
4646
| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, ps_supplycost@2 as ps_supplycost, s_name@3 as s_name, s_address@4 as s_address, s_phone@5 as s_phone, s_acctbal@6 as s_acctbal, s_comment@7 as s_comment, n_name@8 as n_name] |
4747
| | CoalesceBatchesExec: target_batch_size=8192 |
4848
| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 9 }, Column { name: "r_regionkey", index: 0 })] |
@@ -85,7 +85,7 @@
8585
| | FilterExec: r_name@1 = EUROPE |
8686
| | MemoryExec: partitions=0, partition_sizes=[] |
8787
| | CoalesceBatchesExec: target_batch_size=8192 |
88-
| | RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }, Column { name: "__value", index: 1 }], 2), input_partitions=2 |
88+
| | RepartitionExec: partitioning=Hash([Column { name: "__value", index: 1 }, Column { name: "ps_partkey", index: 0 }], 2), input_partitions=2 |
8989
| | ProjectionExec: expr=[ps_partkey@0 as ps_partkey, MIN(partsupp.ps_supplycost)@1 as __value] |
9090
| | AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[MIN(partsupp.ps_supplycost)] |
9191
| | CoalesceBatchesExec: target_batch_size=8192 |

datafusion/core/tests/dataframe.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ use datafusion_expr::expr::{GroupingSet, Sort};
3939
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
4040
use datafusion_expr::Expr::Wildcard;
4141
use datafusion_expr::{
42-
avg, col, count, exists, expr, in_subquery, lit, max, scalar_subquery, sum,
43-
AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
42+
avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery,
43+
sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
4444
WindowFrameUnits, WindowFunction,
4545
};
4646

@@ -241,7 +241,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
241241
scalar_subquery(Arc::new(
242242
ctx.table("t2")
243243
.await?
244-
.filter(col("t1.a").eq(col("t2.a")))?
244+
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
245245
.aggregate(vec![], vec![count(lit(COUNT_STAR_EXPANSION))])?
246246
.select(vec![count(lit(COUNT_STAR_EXPANSION))])?
247247
.into_unoptimized_plan(),

datafusion/core/tests/sql/subqueries.rs

Lines changed: 226 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ async fn invalid_scalar_subquery() -> Result<()> {
185185
let dataframe = ctx.sql(sql).await.expect(&msg);
186186
let err = dataframe.into_optimized_plan().err().unwrap();
187187
assert_eq!(
188-
r#"Context("check_analyzed_plan", Plan("Scalar subquery should only return one column"))"#,
188+
r#"Context("check_analyzed_plan", Plan("Scalar subquery should only return one column, but found 2: t2.t2_id, t2.t2_name"))"#,
189189
&format!("{err:?}")
190190
);
191191

@@ -203,7 +203,231 @@ async fn subquery_not_allowed() -> Result<()> {
203203
let err = dataframe.into_optimized_plan().err().unwrap();
204204

205205
assert_eq!(
206-
r#"Context("check_analyzed_plan", Plan("In/Exist subquery can not be used in Sort plan nodes"))"#,
206+
r#"Context("check_analyzed_plan", Plan("In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes"))"#,
207+
&format!("{err:?}")
208+
);
209+
210+
Ok(())
211+
}
212+
213+
#[tokio::test]
214+
async fn non_aggregated_correlated_scalar_subquery() -> Result<()> {
215+
let ctx = create_join_context("t1_id", "t2_id", true)?;
216+
217+
let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1";
218+
let msg = format!("Creating logical plan for '{sql}'");
219+
let dataframe = ctx.sql(sql).await.expect(&msg);
220+
let err = dataframe.into_optimized_plan().err().unwrap();
221+
222+
assert_eq!(
223+
r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery must be aggregated to return at most one row"))"#,
224+
&format!("{err:?}")
225+
);
226+
227+
let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1";
228+
let msg = format!("Creating logical plan for '{sql}'");
229+
let dataframe = ctx.sql(sql).await.expect(&msg);
230+
let err = dataframe.into_optimized_plan().err().unwrap();
231+
232+
assert_eq!(
233+
r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery must be aggregated to return at most one row"))"#,
234+
&format!("{err:?}")
235+
);
236+
237+
Ok(())
238+
}
239+
240+
#[tokio::test]
241+
async fn non_aggregated_correlated_scalar_subquery_with_limit() -> Result<()> {
242+
let ctx = create_join_context("t1_id", "t2_id", true)?;
243+
244+
let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 2) as t2_int from t1";
245+
let msg = format!("Creating logical plan for '{sql}'");
246+
let dataframe = ctx.sql(sql).await.expect(&msg);
247+
let err = dataframe.into_optimized_plan().err().unwrap();
248+
249+
assert_eq!(
250+
r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery must be aggregated to return at most one row"))"#,
251+
&format!("{err:?}")
252+
);
253+
254+
Ok(())
255+
}
256+
257+
#[tokio::test]
258+
async fn non_aggregated_correlated_scalar_subquery_with_single_row() -> Result<()> {
259+
let ctx = create_join_context("t1_id", "t2_id", true)?;
260+
261+
let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) as t2_int from t1";
262+
let msg = format!("Creating logical plan for '{sql}'");
263+
let dataframe = ctx.sql(sql).await.expect(&msg);
264+
let plan = dataframe.into_optimized_plan()?;
265+
266+
let expected = vec![
267+
"Projection: t1.t1_id, (<subquery>) AS t2_int [t1_id:UInt32;N, t2_int:UInt32;N]",
268+
" Subquery: [t2_int:UInt32;N]",
269+
" Limit: skip=0, fetch=1 [t2_int:UInt32;N]",
270+
" Projection: t2.t2_int [t2_int:UInt32;N]",
271+
" Filter: t2.t2_int = outer_ref(t1.t1_int) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
272+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
273+
" TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]",
274+
];
275+
let formatted = plan.display_indent_schema().to_string();
276+
let actual: Vec<&str> = formatted.trim().lines().collect();
277+
assert_eq!(
278+
expected, actual,
279+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
280+
);
281+
282+
let sql = "SELECT t1_id from t1 where t1_int = (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1)";
283+
let msg = format!("Creating logical plan for '{sql}'");
284+
let dataframe = ctx.sql(sql).await.expect(&msg);
285+
let plan = dataframe.into_optimized_plan()?;
286+
287+
let expected = vec![
288+
"Projection: t1.t1_id [t1_id:UInt32;N]",
289+
" Filter: t1.t1_int = (<subquery>) [t1_id:UInt32;N, t1_int:UInt32;N]",
290+
" Subquery: [t2_int:UInt32;N]",
291+
" Limit: skip=0, fetch=1 [t2_int:UInt32;N]",
292+
" Projection: t2.t2_int [t2_int:UInt32;N]",
293+
" Filter: t2.t2_int = outer_ref(t1.t1_int) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
294+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
295+
" TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]",
296+
];
297+
let formatted = plan.display_indent_schema().to_string();
298+
let actual: Vec<&str> = formatted.trim().lines().collect();
299+
assert_eq!(
300+
expected, actual,
301+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
302+
);
303+
304+
let sql = "SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from t1";
305+
let msg = format!("Creating logical plan for '{sql}'");
306+
let dataframe = ctx.sql(sql).await.expect(&msg);
307+
let plan = dataframe.into_optimized_plan()?;
308+
309+
let expected = vec![
310+
"Projection: t1.t1_id, (<subquery>) AS t2_int [t1_id:UInt32;N, t2_int:Int64]",
311+
" Subquery: [a:Int64]",
312+
" Projection: a [a:Int64]",
313+
" Filter: a = CAST(outer_ref(t1.t1_int) AS Int64) [a:Int64]",
314+
" Projection: Int64(1) AS a [a:Int64]",
315+
" EmptyRelation []",
316+
" TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]",
317+
];
318+
let formatted = plan.display_indent_schema().to_string();
319+
let actual: Vec<&str> = formatted.trim().lines().collect();
320+
assert_eq!(
321+
expected, actual,
322+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
323+
);
324+
325+
Ok(())
326+
}
327+
328+
#[tokio::test]
329+
async fn non_equal_correlated_scalar_subquery() -> Result<()> {
330+
let ctx = create_join_context("t1_id", "t2_id", true)?;
331+
332+
let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1";
333+
let msg = format!("Creating logical plan for '{sql}'");
334+
let dataframe = ctx.sql(sql).await.expect(&msg);
335+
let err = dataframe.into_optimized_plan().err().unwrap();
336+
337+
assert_eq!(
338+
r#"Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: t2.t2_id < outer_ref(t1.t1_id)"))"#,
339+
&format!("{err:?}")
340+
);
341+
342+
Ok(())
343+
}
344+
345+
#[tokio::test]
346+
async fn aggregated_correlated_scalar_subquery() -> Result<()> {
347+
let ctx = create_join_context("t1_id", "t2_id", true)?;
348+
349+
let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1";
350+
let msg = format!("Creating logical plan for '{sql}'");
351+
let dataframe = ctx.sql(sql).await.expect(&msg);
352+
let plan = dataframe.into_optimized_plan()?;
353+
354+
let expected = vec![
355+
"Projection: t1.t1_id, (<subquery>) AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]",
356+
" Subquery: [SUM(t2.t2_int):UInt64;N]",
357+
" Projection: SUM(t2.t2_int) [SUM(t2.t2_int):UInt64;N]",
358+
" Aggregate: groupBy=[[]], aggr=[[SUM(t2.t2_int)]] [SUM(t2.t2_int):UInt64;N]",
359+
" Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
360+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
361+
" TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]",
362+
];
363+
let formatted = plan.display_indent_schema().to_string();
364+
let actual: Vec<&str> = formatted.trim().lines().collect();
365+
assert_eq!(
366+
expected, actual,
367+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
368+
);
369+
370+
Ok(())
371+
}
372+
373+
#[tokio::test]
374+
async fn aggregated_correlated_scalar_subquery_with_extra_group_by_columns() -> Result<()>
375+
{
376+
let ctx = create_join_context("t1_id", "t2_id", true)?;
377+
378+
let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_name) as t2_sum from t1";
379+
let msg = format!("Creating logical plan for '{sql}'");
380+
let dataframe = ctx.sql(sql).await.expect(&msg);
381+
let err = dataframe.into_optimized_plan().err().unwrap();
382+
383+
assert_eq!(
384+
r#"Context("check_analyzed_plan", Plan("A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"))"#,
385+
&format!("{err:?}")
386+
);
387+
388+
Ok(())
389+
}
390+
391+
#[tokio::test]
392+
async fn aggregated_correlated_scalar_subquery_with_extra_group_by_constant() -> Result<()>
393+
{
394+
let ctx = create_join_context("t1_id", "t2_id", true)?;
395+
396+
let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1";
397+
let msg = format!("Creating logical plan for '{sql}'");
398+
let dataframe = ctx.sql(sql).await.expect(&msg);
399+
let plan = dataframe.into_optimized_plan()?;
400+
401+
let expected = vec![
402+
"Projection: t1.t1_id, (<subquery>) AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]",
403+
" Subquery: [SUM(t2.t2_int):UInt64;N]",
404+
" Projection: SUM(t2.t2_int) [SUM(t2.t2_int):UInt64;N]",
405+
" Aggregate: groupBy=[[t2.t2_id, Utf8(\"a\")]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, Utf8(\"a\"):Utf8, SUM(t2.t2_int):UInt64;N]",
406+
" Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
407+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
408+
" TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]",
409+
];
410+
let formatted = plan.display_indent_schema().to_string();
411+
let actual: Vec<&str> = formatted.trim().lines().collect();
412+
assert_eq!(
413+
expected, actual,
414+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
415+
);
416+
417+
Ok(())
418+
}
419+
420+
#[tokio::test]
421+
async fn group_by_correlated_scalar_subquery() -> Result<()> {
422+
let ctx = create_join_context("t1_id", "t2_id", true)?;
423+
let sql = "SELECT sum(t1_int) from t1 GROUP BY (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id)";
424+
425+
let msg = format!("Creating logical plan for '{sql}'");
426+
let dataframe = ctx.sql(sql).await.expect(&msg);
427+
let err = dataframe.into_optimized_plan().err().unwrap();
428+
429+
assert_eq!(
430+
r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"))"#,
207431
&format!("{err:?}")
208432
);
209433

0 commit comments

Comments
 (0)