[SPARK-51519][SQL] MERGE INTO/UPDATE/DELETE support join hint #50524

Status: Open. Wants to merge 1 commit into base: master.

@@ -621,7 +621,7 @@ dmlStatementNoWith
| fromClause multiInsertQueryBody+ #multiInsertQuery
| DELETE FROM identifierReference tableAlias whereClause? #deleteFromTable
| UPDATE identifierReference tableAlias setClause whereClause? #updateTable
- | MERGE (WITH SCHEMA EVOLUTION)? INTO target=identifierReference targetAlias=tableAlias
+ | MERGE (hints+=hint)* (WITH SCHEMA EVOLUTION)? INTO target=identifierReference targetAlias=tableAlias
USING (source=identifierReference |
LEFT_PAREN sourceQuery=query RIGHT_PAREN) sourceAlias=tableAlias
ON mergeCondition=booleanExpression
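With the new `(hints+=hint)*` slot, hints are written between MERGE and the rest of the statement, mirroring SELECT-level hints. A minimal sketch of the accepted syntax — the table `cat.t`, the view `src`, and the `spark` session are hypothetical, not from this PR:

  spark.sql(
    """MERGE /*+ BROADCAST(s) */ INTO cat.t AS t
      |USING src AS s
      |ON t.id = s.id
      |WHEN MATCHED THEN UPDATE SET *
      |WHEN NOT MATCHED THEN INSERT *""".stripMargin)
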
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Attribu
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
- import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
+ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, ResolvedHint, UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
@@ -52,27 +52,11 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
validateMergeIntoConditions(m)
+ buildAppendDataPlan(r, r, source, cond, notMatchedActions)

- // NOT MATCHED conditions may only refer to columns in source so they can be pushed down
- val insertAction = notMatchedActions.head.asInstanceOf[InsertAction]
- val filteredSource = insertAction.condition match {
-   case Some(insertCond) => Filter(insertCond, source)
-   case None => source
- }
-
- // there is only one NOT MATCHED action, use a left anti join to remove any matching rows
- // and switch to using a regular append instead of a row-level MERGE operation
- // only unmatched source rows that match the condition are appended to the table
- val joinPlan = Join(filteredSource, r, LeftAnti, Some(cond), JoinHint.NONE)
-
- val output = insertAction.assignments.map(_.value)
- val outputColNames = r.output.map(_.name)
- val projectList = output.zip(outputColNames).map { case (expr, name) =>
-   Alias(expr, name)()
- }
- val project = Project(projectList, joinPlan)
-
- AppendData.byPosition(r, project)
+ case h @ ResolvedHint(r: DataSourceV2Relation, _) =>
+   validateMergeIntoConditions(m)
+   buildAppendDataPlan(r, h, source, cond, notMatchedActions)

case _ =>
m
@@ -85,35 +69,11 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
validateMergeIntoConditions(m)
+ buildAppendDataPlanForMultipleNotMatchedActions(r, r, source, cond, notMatchedActions)

- // there are only NOT MATCHED actions, use a left anti join to remove any matching rows
- // and switch to using a regular append instead of a row-level MERGE operation
- // only unmatched source rows that match action conditions are appended to the table
- val joinPlan = Join(source, r, LeftAnti, Some(cond), JoinHint.NONE)
-
- val notMatchedInstructions = notMatchedActions.map {
-   case InsertAction(cond, assignments) =>
-     Keep(cond.getOrElse(TrueLiteral), assignments.map(_.value))
-   case other =>
-     throw new AnalysisException(
-       errorClass = "_LEGACY_ERROR_TEMP_3053",
-       messageParameters = Map("other" -> other.toString))
- }
-
- val outputs = notMatchedInstructions.flatMap(_.outputs)
-
- // merge rows as there are multiple NOT MATCHED actions
- val mergeRows = MergeRows(
-   isSourceRowPresent = TrueLiteral,
-   isTargetRowPresent = FalseLiteral,
-   matchedInstructions = Nil,
-   notMatchedInstructions = notMatchedInstructions,
-   notMatchedBySourceInstructions = Nil,
-   checkCardinality = false,
-   output = generateExpandOutput(r.output, outputs),
-   joinPlan)
-
- AppendData.byPosition(r, mergeRows)
+ case h @ ResolvedHint(r: DataSourceV2Relation, _) =>
+   validateMergeIntoConditions(m)
+   buildAppendDataPlanForMultipleNotMatchedActions(r, h, source, cond, notMatchedActions)

case _ =>
m
@@ -137,11 +97,92 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
notMatchedActions, notMatchedBySourceActions)
}

+ case h @ ResolvedHint(
+     r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _), _) =>
+   validateMergeIntoConditions(m)
+   val table = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty())
+   table.operation match {
+     case _: SupportsDelta =>
+       buildWriteDeltaPlan(
+         r, table, source, cond, matchedActions,
+         notMatchedActions, notMatchedBySourceActions, Some(h))
+     case _ =>
+       buildReplaceDataPlan(
+         r, table, source, cond, matchedActions,
+         notMatchedActions, notMatchedBySourceActions, Some(h))
+   }

case _ =>
m
}
}

+ // build a rewrite plan for sources that support appending data
+ private def buildAppendDataPlan(
+     relation: DataSourceV2Relation,
+     target: LogicalPlan,
+     source: LogicalPlan,
+     cond: Expression,
+     notMatchedActions: Seq[MergeAction]): AppendData = {
+   // NOT MATCHED conditions may only refer to columns in source so they can be pushed down
+   val insertAction = notMatchedActions.head.asInstanceOf[InsertAction]
+   val filteredSource = insertAction.condition match {
+     case Some(insertCond) => Filter(insertCond, source)
+     case None => source
+   }
+
+   // there is only one NOT MATCHED action, use a left anti join to remove any matching rows
+   // and switch to using a regular append instead of a row-level MERGE operation
+   // only unmatched source rows that match the condition are appended to the table
+   val joinPlan = Join(filteredSource, target, LeftAnti, Some(cond), JoinHint.NONE)
+
+   val output = insertAction.assignments.map(_.value)
+   val outputColNames = relation.output.map(_.name)
+   val projectList = output.zip(outputColNames).map { case (expr, name) =>
+     Alias(expr, name)()
+   }
+   val project = Project(projectList, joinPlan)
+
+   AppendData.byPosition(relation, project)
+ }
+
+ // build a rewrite plan for sources that support appending data and have multiple NOT MATCHED actions
+ private def buildAppendDataPlanForMultipleNotMatchedActions(
+     relation: DataSourceV2Relation,
+     target: LogicalPlan,
+     source: LogicalPlan,
+     cond: Expression,
+     notMatchedActions: Seq[MergeAction]): AppendData = {
+   // there are only NOT MATCHED actions, use a left anti join to remove any matching rows
+   // and switch to using a regular append instead of a row-level MERGE operation
+   // only unmatched source rows that match action conditions are appended to the table
+   val joinPlan = Join(source, target, LeftAnti, Some(cond), JoinHint.NONE)
+
+   val notMatchedInstructions = notMatchedActions.map {
+     case InsertAction(cond, assignments) =>
+       Keep(cond.getOrElse(TrueLiteral), assignments.map(_.value))
+     case other =>
+       throw new AnalysisException(
+         errorClass = "_LEGACY_ERROR_TEMP_3053",
+         messageParameters = Map("other" -> other.toString))
+   }
+
+   val outputs = notMatchedInstructions.flatMap(_.outputs)
+
+   // merge rows as there are multiple NOT MATCHED actions
+   val mergeRows = MergeRows(
+     isSourceRowPresent = TrueLiteral,
+     isTargetRowPresent = FalseLiteral,
+     matchedInstructions = Nil,
+     notMatchedInstructions = notMatchedInstructions,
+     notMatchedBySourceInstructions = Nil,
+     checkCardinality = false,
+     output = generateExpandOutput(relation.output, outputs),
+     joinPlan)
+
+   AppendData.byPosition(relation, mergeRows)
+ }

// build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
private def buildReplaceDataPlan(
relation: DataSourceV2Relation,
@@ -150,7 +191,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
cond: Expression,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
-     notMatchedBySourceActions: Seq[MergeAction]): ReplaceData = {
+     notMatchedBySourceActions: Seq[MergeAction],
+     hintOption: Option[ResolvedHint] = None): ReplaceData = {

// resolve all required metadata attrs that may be used for grouping data on write
// for instance, JDBC data source may cluster data by shard/host before writing
@@ -159,12 +201,16 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// construct a read relation and include all required metadata columns
val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)

+ val target = hintOption.map { resolvedHint =>
+   resolvedHint.withNewChildren(Seq(readRelation))
+ }.getOrElse(readRelation)
+
val checkCardinality = shouldCheckCardinality(matchedActions)

// use left outer join if there is no NOT MATCHED action, unmatched source rows can be discarded
// use full outer join in all other cases, unmatched source rows may be needed
val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter
- val joinPlan = join(readRelation, source, joinType, cond, checkCardinality)
+ val joinPlan = join(target, source, joinType, cond, checkCardinality)

val mergeRowsPlan = buildReplaceDataMergeRowsPlan(
readRelation, joinPlan, matchedActions, notMatchedActions,
@@ -258,7 +304,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
cond: Expression,
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
-     notMatchedBySourceActions: Seq[MergeAction]): WriteDelta = {
+     notMatchedBySourceActions: Seq[MergeAction],
+     hintOption: Option[ResolvedHint] = None): WriteDelta = {

val operation = operationTable.operation.asInstanceOf[SupportsDelta]

@@ -277,11 +324,14 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
} else {
(readRelation, cond)
}
+ val target = hintOption.map { resolvedHint =>
+   resolvedHint.withNewChildren(Seq(filteredReadRelation))
+ }.getOrElse(filteredReadRelation)

val checkCardinality = shouldCheckCardinality(matchedActions)

val joinType = chooseWriteDeltaJoinType(notMatchedActions, notMatchedBySourceActions)
- val joinPlan = join(filteredReadRelation, source, joinType, joinCond, checkCardinality)
+ val joinPlan = join(target, source, joinType, joinCond, checkCardinality)

val mergeRowsPlan = buildWriteDeltaMergeRowsPlan(
readRelation, joinPlan, matchedActions, notMatchedActions,
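In summary, each target pattern that previously matched only a bare DataSourceV2Relation gains a ResolvedHint twin, and the plans are threaded so the join sees the hinted target while the write still addresses the unwrapped relation. A self-contained toy sketch of the hintOption re-application pattern above — simplified stand-in types, not Spark's:

  sealed trait Plan
  case class Relation(name: String) extends Plan
  case class Hinted(child: Plan, hint: String) extends Plan

  // re-apply an optional hint wrapper on top of a freshly built read relation:
  // the join consumes the hinted plan, the write keeps the bare relation
  def applyHint(readRelation: Plan, hintOption: Option[Hinted]): Plan =
    hintOption.map(h => Hinted(readRelation, h.hint)).getOrElse(readRelation)

  applyHint(Relation("readRelation"), Some(Hinted(Relation("t"), "BROADCAST")))
  // => Hinted(Relation("readRelation"), "BROADCAST")
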
@@ -1109,14 +1109,15 @@ class AstBuilder extends DataTypeAstBuilder
matchedActions, notMatchedActions, notMatchedBySourceActions))
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
- MergeIntoTable(
+ val plan: LogicalPlan = MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
matchedActions,
notMatchedActions,
notMatchedBySourceActions,
withSchemaEvolution)
+ ctx.hints.asScala.foldRight(plan)(withHints)
}

/**
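The parsed hints are folded onto the finished MergeIntoTable node with foldRight, so the hint written first ends up outermost, matching how SELECT hints nest. A toy model of that fold — illustrative types, not the AstBuilder API:

  case class HintCtx(name: String)
  case class Node(desc: String)

  // stand-in for AstBuilder's withHints helper, which wraps a plan in one hint node
  def withHints(h: HintCtx, p: Node): Node =
    Node(s"UnresolvedHint(${h.name}, ${p.desc})")

  val plan = List(HintCtx("BROADCAST"), HintCtx("MERGE"))
    .foldRight(Node("MergeIntoTable(...)"))(withHints)
  // plan.desc == "UnresolvedHint(BROADCAST, UnresolvedHint(MERGE, MergeIntoTable(...)))"
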
@@ -768,6 +768,7 @@ case class MergeIntoTable(
lazy val rewritable: Boolean = {
EliminateSubqueryAliases(targetTable) match {
case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true
+ case ResolvedHint(DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _), _) => true
case _ => false
}
}
@@ -54,12 +54,14 @@ class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] {
CommandResult(
Seq.empty,
call,
+ call,
LocalTableScanExec(Seq.empty, Seq.empty, None),
Seq.empty)
case Seq(relation: LocalRelation) =>
CommandResult(
relation.output,
call,
+ call,
LocalTableScanExec(relation.output, relation.data, None),
relation.data)
case _ =>
@@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.SparkPlan
case class CommandResult(
output: Seq[Attribute],
@transient commandLogicalPlan: LogicalPlan,
+ @transient commandOptimizedLogicalPlan: LogicalPlan,
@transient commandPhysicalPlan: SparkPlan,
@transient rows: Seq[InternalRow]) extends LeafNode {
override def innerChildren: Seq[QueryPlan[_]] = Seq(commandLogicalPlan)
@@ -147,6 +147,7 @@ class QueryExecution(
CommandResult(
qe.analyzed.output,
qe.commandExecuted,
+ qe.optimizedPlan,
qe.executedPlan,
result.toImmutableArraySeq)
}
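CommandResult now also records the command's optimized logical plan, populated here from qe.optimizedPlan, so callers can inspect what the optimizer actually produced for an eagerly executed command — for example, whether a join hint survived the MERGE rewrite. A hedged sketch, assuming a DataFrame df obtained from a DML statement:

  import org.apache.spark.sql.execution.CommandResult

  df.queryExecution.optimizedPlan match {
    case CommandResult(_, _, optimizedCommand, _, _) =>
      // the command's post-optimizer plan, where a surviving join hint would show up
      println(optimizedCommand.treeString)
    case other =>
      println(s"not an eagerly executed command: ${other.nodeName}")
  }
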
@@ -1013,7 +1013,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.LocalRelation(output, data, _, stream) =>
LocalTableScanExec(output, data, stream) :: Nil
case logical.EmptyRelation(l) => EmptyRelationExec(l) :: Nil
- case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil
+ case CommandResult(output, _, _, plan, data) => CommandResultExec(output, plan, data) :: Nil
// We should match the combination of limit and offset first, to get the optimal physical
// plan, instead of planning limit and offset separately.
case LimitAndOffset(limit, offset, child) =>
@@ -102,7 +102,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")

var collected = df.queryExecution.optimizedPlan.collect {
- case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
+ case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _, _) =>
assert(relation.options.get("write.split-size") == "10")
}
assert (collected.size == 1)
@@ -187,7 +187,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase
var collected = df.queryExecution.optimizedPlan.collect {
case CommandResult(_,
OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
-     _, _) =>
+     _, _, _) =>
assert(relation.options.get("write.split-size") === "10")
}
assert (collected.size == 1)
@@ -247,7 +247,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase
var collected = df.queryExecution.optimizedPlan.collect {
case CommandResult(_,
OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
-     _, _) =>
+     _, _, _) =>
assert(relation.options.get("write.split-size") == "10")
}
assert (collected.size == 1)
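With the optimized command plan now exposed through CommandResult, a join-hint assertion for MERGE can follow the same pattern as the option checks above. A hedged sketch — the table names, hint choice, and assertions are illustrative, not part of this PR:

  import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
  import org.apache.spark.sql.execution.CommandResult

  val df = sql(s"MERGE /*+ BROADCAST(s) */ INTO $t1 t USING source s " +
    "ON t.id = s.id WHEN MATCHED THEN UPDATE SET *")
  val collected = df.queryExecution.optimizedPlan.collect {
    case CommandResult(_, _, optimizedCommand, _, _) =>
      // the rewritten MERGE should contain a join that still carries the hint
      assert(optimizedCommand.collect { case j: Join => j.hint }
        .exists(_ != JoinHint.NONE))
  }
  assert(collected.size == 1)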