|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.plans.logical
|
19 | 19 |
|
20 |
| -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} |
| 20 | +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SubqueryExpression} |
21 | 21 | import org.apache.spark.sql.catalyst.plans.QueryPlan
|
22 | 22 | import org.apache.spark.sql.catalyst.rules.RuleId
|
23 | 23 | import org.apache.spark.sql.catalyst.rules.UnknownRuleId
|
| 24 | +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION |
24 | 25 | import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreePatternBits}
|
25 | 26 | import org.apache.spark.sql.errors.QueryExecutionErrors
|
26 | 27 | import org.apache.spark.util.Utils
|
@@ -155,6 +156,35 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
|
155 | 156 | }
|
156 | 157 | }
|
157 | 158 |
|
| 159 | + /** |
| 160 | + * Similar to [[resolveOperatorsUpWithPruning]], but also applies the given partial function to |
| 161 | + * all the plans in the subqueries of all nodes. This method is useful when we want to rewrite the |
| 162 | + * whole plan, including its subqueries, in one go. |
| 163 | + */ |
| 164 | + def resolveOperatorsUpWithSubqueriesAndPruning( |
| 165 | + cond: TreePatternBits => Boolean, |
| 166 | + ruleId: RuleId = UnknownRuleId)( |
| 167 | + rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { |
| 168 | + val visit: PartialFunction[LogicalPlan, LogicalPlan] = |
| 169 | + new PartialFunction[LogicalPlan, LogicalPlan] { |
| 170 | + override def isDefinedAt(x: LogicalPlan): Boolean = true |
| 171 | + |
| 172 | + override def apply(plan: LogicalPlan): LogicalPlan = { |
| 173 | + val transformed = plan.transformExpressionsUpWithPruning( |
| 174 | + t => t.containsPattern(PLAN_EXPRESSION) && cond(t) |
| 175 | + ) { |
| 176 | + case subquery: SubqueryExpression => |
| 177 | + val newPlan = |
| 178 | + subquery.plan.resolveOperatorsUpWithSubqueriesAndPruning(cond, ruleId)(rule) |
| 179 | + subquery.withNewPlan(newPlan) |
| 180 | + } |
| 181 | + rule.applyOrElse[LogicalPlan, LogicalPlan](transformed, identity) |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + resolveOperatorsUpWithPruning(cond, ruleId)(visit) |
| 186 | + } |
| 187 | + |
158 | 188 | /** Similar to [[resolveOperatorsUp]], but does it top-down. */
|
159 | 189 | def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
|
160 | 190 | resolveOperatorsDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
|
|
0 commit comments