Skip to content

planner: fix doesn't push down hash join to tiflash #60436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/planner/core/casetest/tpch/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ go_test(
],
data = glob(["testdata/**"]),
flaky = True,
shard_count = 3,
shard_count = 4,
deps = [
"//pkg/config",
"//pkg/testkit",
Expand Down
6 changes: 6 additions & 0 deletions pkg/planner/core/casetest/tpch/testdata/tpch_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
"explain format='brief' SELECT /*+ SHUFFLE_JOIN(orders, lineitem) */ o.o_orderdate, SUM(l.l_extendedprice * (1 - l.l_discount)) AS revenue FROM orders AS o JOIN lineitem AS l ON o.o_orderkey = l.l_orderkey WHERE o.o_orderdate BETWEEN '1994-01-01' AND '1994-12-31' GROUP BY o.o_orderdate ORDER BY revenue DESC LIMIT 10;"
]
},
{
"name": "TestQ9",
"cases": [
"explain format='brief' SELECT nation, o_year, SUM(amount) AS sum_profit FROM (SELECT n_name AS nation, EXTRACT(YEAR FROM o_orderdate) AS o_year, l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity AS amount FROM part, supplier, lineitem, partsupp, orders, nation WHERE s_suppkey = l_suppkey AND ps_suppkey = l_suppkey AND ps_partkey = l_partkey AND p_partkey = l_partkey AND o_orderkey = l_orderkey AND s_nationkey = n_nationkey AND p_name LIKE '%dim%') AS profit GROUP BY nation, o_year ORDER BY nation, o_year DESC;"
]
},
{
"name": "TestQ13",
"cases": [
Expand Down
58 changes: 58 additions & 0 deletions pkg/planner/core/casetest/tpch/testdata/tpch_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,64 @@
}
]
},
{
"Name": "TestQ9",
"Cases": [
{
"SQL": "explain format='brief' SELECT nation, o_year, SUM(amount) AS sum_profit FROM (SELECT n_name AS nation, EXTRACT(YEAR FROM o_orderdate) AS o_year, l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity AS amount FROM part, supplier, lineitem, partsupp, orders, nation WHERE s_suppkey = l_suppkey AND ps_suppkey = l_suppkey AND ps_partkey = l_partkey AND p_partkey = l_partkey AND o_orderkey = l_orderkey AND s_nationkey = n_nationkey AND p_name LIKE '%dim%') AS profit GROUP BY nation, o_year ORDER BY nation, o_year DESC;",
"Result": [
"Sort 8000.00 root test.nation.n_name, Column#52:desc",
"└─TableReader 8000.00 root MppVersion: 3, data:ExchangeSender",
" └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection 8000.00 mpp[tiflash] test.nation.n_name, Column#52, Column#54",
" └─Projection 8000.00 mpp[tiflash] Column#54, test.nation.n_name, Column#52",
" └─HashAgg 8000.00 mpp[tiflash] group by:Column#72, test.nation.n_name, funcs:sum(Column#73)->Column#54, funcs:firstrow(test.nation.n_name)->test.nation.n_name, funcs:firstrow(Column#72)->Column#52",
" └─ExchangeReceiver 8000.00 mpp[tiflash] ",
" └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.nation.n_name, collate: utf8mb4_bin]",
" └─HashAgg 8000.00 mpp[tiflash] group by:Column#77, Column#78, funcs:sum(Column#76)->Column#73",
" └─Projection 24414.06 mpp[tiflash] minus(mul(test.lineitem.l_extendedprice, minus(1, test.lineitem.l_discount)), mul(test.partsupp.ps_supplycost, test.lineitem.l_quantity))->Column#76, test.nation.n_name->Column#77, extract(YEAR, test.orders.o_orderdate)->Column#78",
" └─Projection 24414.06 mpp[tiflash] test.lineitem.l_quantity, test.lineitem.l_extendedprice, test.lineitem.l_discount, test.partsupp.ps_supplycost, test.orders.o_orderdate, test.nation.n_name",
" └─Projection 24414.06 mpp[tiflash] test.lineitem.l_quantity, test.lineitem.l_extendedprice, test.lineitem.l_discount, test.partsupp.ps_supplycost, test.orders.o_orderdate, test.nation.n_name, test.supplier.s_nationkey",
" └─HashJoin 24414.06 mpp[tiflash] inner join, equal:[eq(test.supplier.s_nationkey, test.nation.n_nationkey)]",
" ├─ExchangeReceiver(Build) 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.nation.n_nationkey, collate: binary]",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:nation keep order:false, stats:pseudo",
" └─ExchangeReceiver(Probe) 19531.25 mpp[tiflash] ",
" └─ExchangeSender 19531.25 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.supplier.s_nationkey, collate: binary]",
" └─Projection 19531.25 mpp[tiflash] test.lineitem.l_quantity, test.lineitem.l_extendedprice, test.lineitem.l_discount, test.supplier.s_nationkey, test.partsupp.ps_supplycost, test.orders.o_orderdate, test.lineitem.l_orderkey",
" └─HashJoin 19531.25 mpp[tiflash] inner join, equal:[eq(test.lineitem.l_orderkey, test.orders.o_orderkey)]",
" ├─ExchangeReceiver(Build) 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.orders.o_orderkey, collate: binary]",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:orders keep order:false, stats:pseudo",
" └─ExchangeReceiver(Probe) 15625.00 mpp[tiflash] ",
" └─ExchangeSender 15625.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.lineitem.l_orderkey, collate: binary]",
" └─Projection 15625.00 mpp[tiflash] test.lineitem.l_orderkey, test.lineitem.l_quantity, test.lineitem.l_extendedprice, test.lineitem.l_discount, test.supplier.s_nationkey, test.partsupp.ps_supplycost, test.lineitem.l_suppkey, test.lineitem.l_partkey",
" └─HashJoin 15625.00 mpp[tiflash] inner join, equal:[eq(test.lineitem.l_suppkey, test.partsupp.ps_suppkey) eq(test.lineitem.l_partkey, test.partsupp.ps_partkey)]",
" ├─ExchangeReceiver(Build) 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.partsupp.ps_suppkey, collate: binary], [name: test.partsupp.ps_partkey, collate: binary]",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:partsupp keep order:false, stats:pseudo",
" └─ExchangeReceiver(Probe) 12500.00 mpp[tiflash] ",
" └─ExchangeSender 12500.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.lineitem.l_suppkey, collate: binary], [name: test.lineitem.l_partkey, collate: binary]",
" └─Projection 12500.00 mpp[tiflash] test.lineitem.l_orderkey, test.lineitem.l_partkey, test.lineitem.l_suppkey, test.lineitem.l_quantity, test.lineitem.l_extendedprice, test.lineitem.l_discount, test.supplier.s_nationkey, test.supplier.s_suppkey",
" └─HashJoin 12500.00 mpp[tiflash] inner join, equal:[eq(test.lineitem.l_suppkey, test.supplier.s_suppkey)]",
" ├─ExchangeReceiver(Build) 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.lineitem.l_suppkey, collate: binary]",
" │ └─Projection 10000.00 mpp[tiflash] test.lineitem.l_orderkey, test.lineitem.l_partkey, test.lineitem.l_suppkey, test.lineitem.l_quantity, test.lineitem.l_extendedprice, test.lineitem.l_discount",
" │ └─HashJoin 10000.00 mpp[tiflash] inner join, equal:[eq(test.part.p_partkey, test.lineitem.l_partkey)]",
" │ ├─ExchangeReceiver(Build) 8000.00 mpp[tiflash] ",
" │ │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.part.p_partkey, collate: binary]",
" │ │ └─Selection 8000.00 mpp[tiflash] like(test.part.p_name, \"%dim%\", 92)",
" │ │ └─TableFullScan 10000.00 mpp[tiflash] table:part pushed down filter:empty, keep order:false, stats:pseudo",
" │ └─ExchangeReceiver(Probe) 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.lineitem.l_partkey, collate: binary]",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:lineitem keep order:false, stats:pseudo",
" └─ExchangeReceiver(Probe) 10000.00 mpp[tiflash] ",
" └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.supplier.s_suppkey, collate: binary]",
" └─TableFullScan 10000.00 mpp[tiflash] table:supplier keep order:false, stats:pseudo"
]
}
]
},
{
"Name": "TestQ13",
"Cases": [
Expand Down
103 changes: 103 additions & 0 deletions pkg/planner/core/casetest/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,109 @@ CREATE TABLE lineitem (
}
}

func TestQ9(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec(`
CREATE TABLE lineitem (
L_ORDERKEY bigint NOT NULL,
L_PARTKEY bigint NOT NULL,
L_SUPPKEY bigint NOT NULL,
L_LINENUMBER bigint NOT NULL,
L_QUANTITY decimal(15,2) NOT NULL,
L_EXTENDEDPRICE decimal(15,2) NOT NULL,
L_DISCOUNT decimal(15,2) NOT NULL,
L_TAX decimal(15,2) NOT NULL,
L_RETURNFLAG char(1) NOT NULL,
L_LINESTATUS char(1) NOT NULL,
L_SHIPDATE date NOT NULL,
L_COMMITDATE date NOT NULL,
L_RECEIPTDATE date NOT NULL,
L_SHIPINSTRUCT char(25) NOT NULL,
L_SHIPMODE char(10) NOT NULL,
L_COMMENT varchar(44) NOT NULL,
PRIMARY KEY (L_ORDERKEY, L_LINENUMBER) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
`)
tk.MustExec(`CREATE TABLE nation (
N_NATIONKEY bigint NOT NULL,
N_NAME char(25) NOT NULL,
N_REGIONKEY bigint NOT NULL,
N_COMMENT varchar(152) DEFAULT NULL,
PRIMARY KEY (N_NATIONKEY) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin`)
tk.MustExec(`
CREATE TABLE orders (
O_ORDERKEY bigint NOT NULL,
O_CUSTKEY bigint NOT NULL,
O_ORDERSTATUS char(1) NOT NULL,
O_TOTALPRICE decimal(15,2) NOT NULL,
O_ORDERDATE date NOT NULL,
O_ORDERPRIORITY char(15) NOT NULL,
O_CLERK char(15) NOT NULL,
O_SHIPPRIORITY bigint NOT NULL,
O_COMMENT varchar(79) NOT NULL,
PRIMARY KEY (O_ORDERKEY) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;`)
tk.MustExec(`CREATE TABLE part (
P_PARTKEY bigint NOT NULL,
P_NAME varchar(55) NOT NULL,
P_MFGR char(25) NOT NULL,
P_BRAND char(10) NOT NULL,
P_TYPE varchar(25) NOT NULL,
P_SIZE bigint NOT NULL,
P_CONTAINER char(10) NOT NULL,
P_RETAILPRICE decimal(15,2) NOT NULL,
P_COMMENT varchar(23) NOT NULL,
PRIMARY KEY (P_PARTKEY) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin`)
tk.MustExec(`CREATE TABLE partsupp (
PS_PARTKEY bigint NOT NULL,
PS_SUPPKEY bigint NOT NULL,
PS_AVAILQTY bigint NOT NULL,
PS_SUPPLYCOST decimal(15,2) NOT NULL,
PS_COMMENT varchar(199) NOT NULL,
PRIMARY KEY (PS_PARTKEY,PS_SUPPKEY) /*T![clustered_index] NONCLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin`)
tk.MustExec(`CREATE TABLE supplier (
S_SUPPKEY bigint NOT NULL,
S_NAME char(25) NOT NULL,
S_ADDRESS varchar(40) NOT NULL,
S_NATIONKEY bigint NOT NULL,
S_PHONE char(15) NOT NULL,
S_ACCTBAL decimal(15,2) NOT NULL,
S_COMMENT varchar(101) NOT NULL,
PRIMARY KEY (S_SUPPKEY) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin`)
tk.MustExec("set @@session.tidb_broadcast_join_threshold_size = 0")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_count = 0")
testkit.SetTiFlashReplica(t, dom, "test", "orders")
testkit.SetTiFlashReplica(t, dom, "test", "lineitem")
testkit.SetTiFlashReplica(t, dom, "test", "nation")
testkit.SetTiFlashReplica(t, dom, "test", "part")
testkit.SetTiFlashReplica(t, dom, "test", "partsupp")
testkit.SetTiFlashReplica(t, dom, "test", "supplier")
integrationSuiteData := GetTPCHSuiteData()
var (
input []string
output []struct {
SQL string
Result []string
}
)
integrationSuiteData.LoadTestCases(t, &input, &output)
for i := 0; i < len(input); i++ {
testdata.OnRecord(func() {
output[i].SQL = input[i]
})
testdata.OnRecord(func() {
output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(input[i]).Rows())
})
tk.MustQuery(input[i]).Check(testkit.Rows(output[i].Result...))
}
}

func TestQ13(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
tk := testkit.NewTestKit(t, store)
Expand Down
34 changes: 30 additions & 4 deletions pkg/planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,43 @@ func (p *PhysicalHashJoin) convertPartitionKeysIfNeed(lTask, rTask *MppTask) (_,
return lTask, rTask
}

func (p *PhysicalHashJoin) enforceExchangerByBackup(task *MppTask, idx int, expectedCols int) *MppTask {
if backupHashProp := p.GetChildReqProps(idx); backupHashProp != nil {
if len(backupHashProp.MPPPartitionCols) == expectedCols {
return task.enforceExchangerImpl(backupHashProp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we first eliminate the ex and add it back if necessary. i afraid the complicated case, currently looks fine to me

}
}
return nil
}

func (p *PhysicalHashJoin) attach2TaskForMpp(tasks ...base.Task) base.Task {
rTask, rok := tasks[1].(*MppTask)
lTask, lok := tasks[0].(*MppTask)
const (
left = 0
right = 1
)
rTask, rok := tasks[right].(*MppTask)
lTask, lok := tasks[left].(*MppTask)
if !lok || !rok {
return base.InvalidTask
}
if p.mppShuffleJoin {
// protection check is case of some bugs
if len(lTask.hashCols) != len(rTask.hashCols) || len(lTask.hashCols) == 0 {
if len(lTask.hashCols) == 0 || len(rTask.hashCols) == 0 {
// if the hash columns are empty, this is very likely a bug.
return base.InvalidTask
}
if len(lTask.hashCols) != len(rTask.hashCols) {
// if the hash columns are not the same, The most likely scenario is that
// they have undergone exchange optimization, removing some hash columns.
// In this case, we need to restore them on the side that is missing.
if len(lTask.hashCols) < len(rTask.hashCols) {
lTask = p.enforceExchangerByBackup(lTask, left, len(rTask.hashCols))
} else {
rTask = p.enforceExchangerByBackup(rTask, right, len(lTask.hashCols))
}
if lTask == nil || rTask == nil {
return base.InvalidTask
}
}
lTask, rTask = p.convertPartitionKeysIfNeed(lTask, rTask)
}
p.SetChildren(lTask.Plan(), rTask.Plan())
Expand Down