Skip to content

Commit 2b978d8

Browse files
committed
fix
1 parent 75d80c7 commit 2b978d8

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

Diff for: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ object ApplyCharTypePaddingHelper {
6868
private[sql] def paddingForStringComparison(
6969
plan: LogicalPlan,
7070
padCharCol: Boolean): LogicalPlan = {
71-
plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
71+
plan.resolveOperatorsUpWithSubqueriesAndPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
7272
case operator =>
7373
operator.transformExpressionsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
7474
case e if !e.childrenResolved => e

Diff for: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

+31-1
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical
1919

20-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
20+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SubqueryExpression}
2121
import org.apache.spark.sql.catalyst.plans.QueryPlan
2222
import org.apache.spark.sql.catalyst.rules.RuleId
2323
import org.apache.spark.sql.catalyst.rules.UnknownRuleId
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
2425
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreePatternBits}
2526
import org.apache.spark.sql.errors.QueryExecutionErrors
2627
import org.apache.spark.util.Utils
@@ -155,6 +156,35 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
155156
}
156157
}
157158

159+
/**
160+
* Similar to [[resolveOperatorsUpWithPruning]], but also applies the given partial function to
161+
* all the plans in the subqueries of all nodes. This method is useful when we want to rewrite the
162+
* whole plan, including its subqueries, in one go.
163+
*/
164+
def resolveOperatorsUpWithSubqueriesAndPruning(
165+
cond: TreePatternBits => Boolean,
166+
ruleId: RuleId = UnknownRuleId)(
167+
rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
168+
val visit: PartialFunction[LogicalPlan, LogicalPlan] =
169+
new PartialFunction[LogicalPlan, LogicalPlan] {
170+
override def isDefinedAt(x: LogicalPlan): Boolean = true
171+
172+
override def apply(plan: LogicalPlan): LogicalPlan = {
173+
val transformed = plan.transformExpressionsUpWithPruning(
174+
t => t.containsPattern(PLAN_EXPRESSION) && cond(t)
175+
) {
176+
case subquery: SubqueryExpression =>
177+
val newPlan =
178+
subquery.plan.resolveOperatorsUpWithSubqueriesAndPruning(cond, ruleId)(rule)
179+
subquery.withNewPlan(newPlan)
180+
}
181+
rule.applyOrElse[LogicalPlan, LogicalPlan](transformed, identity)
182+
}
183+
}
184+
185+
resolveOperatorsUpWithPruning(cond, ruleId)(visit)
186+
}
187+
158188
/** Similar to [[resolveOperatorsUp]], but does it top-down. */
159189
def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
160190
resolveOperatorsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)

0 commit comments

Comments
 (0)