Skip to content

Commit f69fe09

Browse files
authored
Merge pull request #338 from mcarmonaa/fix/squash-natural-joins
internal/rule: fix squashjoins rule to squash projections properly too
2 parents 51f5d0c + 015bbfa commit f69fe09

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

integration_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,17 @@ func TestIntegration(t *testing.T) {
132132
{int32(4), "b029517f6300c2da0f4b651b8642506cd6aaf45d"},
133133
},
134134
},
135+
{
136+
`SELECT count(1), refs.repository_id
137+
FROM refs
138+
NATURAL JOIN commits
139+
NATURAL JOIN commit_blobs
140+
WHERE refs.ref_name = 'HEAD'
141+
GROUP BY refs.repository_id`,
142+
[]sql.Row{
143+
{int32(9), "worktree"},
144+
},
145+
},
135146
}
136147

137148
runTests := func(t *testing.T) {

internal/rule/squashjoins.go

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ func SquashJoins(
3131
defer span.Finish()
3232

3333
a.Log("squashing joins, node of type %T", n)
34+
35+
projectSquashes := countProjectSquashes(n)
36+
3437
n, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
3538
join, ok := n.(*plan.InnerJoin)
3639
if !ok {
@@ -39,18 +42,100 @@ func SquashJoins(
3942

4043
return squashJoin(join)
4144
})
45+
4246
if err != nil {
4347
return nil, err
4448
}
4549

46-
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
50+
n, err = n.TransformUp(func(n sql.Node) (sql.Node, error) {
4751
t, ok := n.(*joinedTables)
4852
if !ok {
4953
return n, nil
5054
}
5155

5256
return buildSquashedTable(t.tables, t.filters, t.columns, t.indexes)
5357
})
58+
59+
if err != nil {
60+
return nil, err
61+
}
62+
63+
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
64+
if projectSquashes <= 0 {
65+
return n, nil
66+
}
67+
68+
project, ok := n.(*plan.Project)
69+
if !ok {
70+
return n, nil
71+
}
72+
73+
child, ok := project.Child.(*plan.Project)
74+
if !ok {
75+
return n, nil
76+
}
77+
78+
squashedProject, ok := squashProjects(project, child)
79+
if !ok {
80+
return n, nil
81+
}
82+
83+
projectSquashes--
84+
return squashedProject, nil
85+
})
86+
}
87+
88+
func countProjectSquashes(n sql.Node) int {
89+
var squashableProjects int
90+
plan.Inspect(n, func(node sql.Node) bool {
91+
if project, ok := node.(*plan.Project); ok {
92+
if _, ok := project.Child.(*plan.InnerJoin); ok {
93+
squashableProjects++
94+
}
95+
}
96+
97+
return true
98+
})
99+
100+
return squashableProjects - 1
101+
}
102+
103+
func squashProjects(parent, child *plan.Project) (sql.Node, bool) {
104+
projections := []sql.Expression{}
105+
for _, expr := range parent.Expressions() {
106+
parentField, ok := expr.(*expression.GetField)
107+
if !ok {
108+
return nil, false
109+
}
110+
111+
index := parentField.Index()
112+
for _, e := range child.Expressions() {
113+
childField, ok := e.(*expression.GetField)
114+
if !ok {
115+
return nil, false
116+
}
117+
118+
if referenceSameColumn(parentField, childField) {
119+
index = childField.Index()
120+
}
121+
}
122+
123+
projection := expression.NewGetFieldWithTable(
124+
index,
125+
parentField.Type(),
126+
parentField.Table(),
127+
parentField.Name(),
128+
parentField.IsNullable(),
129+
)
130+
131+
projections = append(projections, projection)
132+
}
133+
134+
return plan.NewProject(projections, child.Child), true
135+
}
136+
137+
func referenceSameColumn(parent, child *expression.GetField) bool {
138+
return parent.Name() == child.Name() && parent.Table() == child.Table()
54139
}
55140

56141
func squashJoin(join *plan.InnerJoin) (sql.Node, error) {

0 commit comments

Comments
 (0)