Skip to content

Commit 0594863

Browse files
committed
fix
1 parent 418cfd1 commit 0594863

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ object ApplyCharTypePaddingHelper {
6868
private[sql] def paddingForStringComparison(
6969
plan: LogicalPlan,
7070
padCharCol: Boolean): LogicalPlan = {
71-
plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
71+
plan.resolveOperatorsUpWithSubqueriesAndPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
7272
case operator =>
7373
operator.transformExpressionsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
7474
case e if !e.childrenResolved => e

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical
1919

20-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
20+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SubqueryExpression}
2121
import org.apache.spark.sql.catalyst.plans.QueryPlan
2222
import org.apache.spark.sql.catalyst.rules.RuleId
2323
import org.apache.spark.sql.catalyst.rules.UnknownRuleId
2424
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreePatternBits}
25+
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
2526
import org.apache.spark.sql.errors.QueryExecutionErrors
2627
import org.apache.spark.util.Utils
2728

@@ -155,6 +156,35 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
155156
}
156157
}
157158

159+
/**
160+
* Similar to [[resolveOperatorsUpWithPruning]], but also applies the given partial function to
161+
* all the plans in the subqueries of all nodes. This method is useful when we want to rewrite the
162+
* whole plan, including its subqueries, in one go.
163+
*/
164+
def resolveOperatorsUpWithSubqueriesAndPruning(
165+
cond: TreePatternBits => Boolean,
166+
ruleId: RuleId = UnknownRuleId)(
167+
rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
168+
val visit: PartialFunction[LogicalPlan, LogicalPlan] =
169+
new PartialFunction[LogicalPlan, LogicalPlan] {
170+
override def isDefinedAt(x: LogicalPlan): Boolean = true
171+
172+
override def apply(plan: LogicalPlan): LogicalPlan = {
173+
val transformed = plan.transformExpressionsUpWithPruning(
174+
t => t.containsPattern(PLAN_EXPRESSION) && cond(t)
175+
) {
176+
case subquery: SubqueryExpression =>
177+
val newPlan =
178+
subquery.plan.resolveOperatorsUpWithSubqueriesAndPruning(cond, ruleId)(rule)
179+
subquery.withNewPlan(newPlan)
180+
}
181+
rule.applyOrElse[LogicalPlan, LogicalPlan](transformed, identity)
182+
}
183+
}
184+
185+
resolveOperatorsUpWithPruning(cond, ruleId)(visit)
186+
}
187+
158188
/** Similar to [[resolveOperatorsUp]], but does it top-down. */
159189
def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
160190
resolveOperatorsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)

sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
package org.apache.spark.sql
1919

2020
import org.apache.spark.{SparkConf, SparkRuntimeException}
21-
import org.apache.spark.sql.catalyst.expressions.Attribute
21+
import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo, GreaterThan, ScalarSubquery, StringRPad}
2222
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLId
2323
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
24-
import org.apache.spark.sql.catalyst.plans.logical.Project
24+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project}
2525
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
2626
import org.apache.spark.sql.connector.SchemaRequiredDataSource
2727
import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
@@ -798,6 +798,38 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
798798
)
799799
}
800800
}
801+
802+
test(
803+
"SPARK-51732: rpad should be applied on attributes with same ExprId if those attributes " +
804+
"should be deduplicated 2"
805+
) {
806+
withSQLConf(
807+
SQLConf.READ_SIDE_CHAR_PADDING.key -> "false",
808+
SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE.key -> "false"
809+
) {
810+
withTable("mytable") {
811+
sql(s"CREATE TABLE mytable(col CHAR(10))")
812+
val plan = sql(
813+
"""
814+
|SELECT t1.col
815+
|FROM mytable t1
816+
|WHERE (SELECT count(*) AS cnt FROM mytable t2 WHERE (t1.col = t2.col)) > 0
817+
""".stripMargin).queryExecution.analyzed
818+
val subquery = plan.asInstanceOf[Project]
819+
.child.asInstanceOf[Filter]
820+
.condition.asInstanceOf[GreaterThan]
821+
.left.asInstanceOf[ScalarSubquery]
822+
val subqueryFilterCondition = subquery.plan.asInstanceOf[Aggregate]
823+
.child.asInstanceOf[Filter]
824+
.condition.asInstanceOf[EqualTo]
825+
826+
// rpad should be applied to both left and right hand side of t1.col = t2.col because the
827+
// attributes are deduplicated.
828+
assert(subqueryFilterCondition.left.isInstanceOf[StringRPad])
829+
assert(subqueryFilterCondition.right.isInstanceOf[StringRPad])
830+
}
831+
}
832+
}
801833
}
802834

803835
// Some basic char/varchar tests which doesn't rely on table implementation.

0 commit comments

Comments
 (0)