@@ -31,6 +31,9 @@ func SquashJoins(
31
31
defer span .Finish ()
32
32
33
33
a .Log ("squashing joins, node of type %T" , n )
34
+
35
+ projectSquashes := countProjectSquashes (n )
36
+
34
37
n , err := n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
35
38
join , ok := n .(* plan.InnerJoin )
36
39
if ! ok {
@@ -39,18 +42,100 @@ func SquashJoins(
39
42
40
43
return squashJoin (join )
41
44
})
45
+
42
46
if err != nil {
43
47
return nil , err
44
48
}
45
49
46
- return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
50
+ n , err = n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
47
51
t , ok := n .(* joinedTables )
48
52
if ! ok {
49
53
return n , nil
50
54
}
51
55
52
56
return buildSquashedTable (t .tables , t .filters , t .columns , t .indexes )
53
57
})
58
+
59
+ if err != nil {
60
+ return nil , err
61
+ }
62
+
63
+ return n .TransformUp (func (n sql.Node ) (sql.Node , error ) {
64
+ if projectSquashes <= 0 {
65
+ return n , nil
66
+ }
67
+
68
+ project , ok := n .(* plan.Project )
69
+ if ! ok {
70
+ return n , nil
71
+ }
72
+
73
+ child , ok := project .Child .(* plan.Project )
74
+ if ! ok {
75
+ return n , nil
76
+ }
77
+
78
+ squashedProject , ok := squashProjects (project , child )
79
+ if ! ok {
80
+ return n , nil
81
+ }
82
+
83
+ projectSquashes --
84
+ return squashedProject , nil
85
+ })
86
+ }
87
+
88
+ func countProjectSquashes (n sql.Node ) int {
89
+ var squashableProjects int
90
+ plan .Inspect (n , func (node sql.Node ) bool {
91
+ if project , ok := node .(* plan.Project ); ok {
92
+ if _ , ok := project .Child .(* plan.InnerJoin ); ok {
93
+ squashableProjects ++
94
+ }
95
+ }
96
+
97
+ return true
98
+ })
99
+
100
+ return squashableProjects - 1
101
+ }
102
+
103
+ func squashProjects (parent , child * plan.Project ) (sql.Node , bool ) {
104
+ projections := []sql.Expression {}
105
+ for _ , expr := range parent .Expressions () {
106
+ parentField , ok := expr .(* expression.GetField )
107
+ if ! ok {
108
+ return nil , false
109
+ }
110
+
111
+ index := parentField .Index ()
112
+ for _ , e := range child .Expressions () {
113
+ childField , ok := e .(* expression.GetField )
114
+ if ! ok {
115
+ return nil , false
116
+ }
117
+
118
+ if referenceSameColumn (parentField , childField ) {
119
+ index = childField .Index ()
120
+ }
121
+ }
122
+
123
+ projection := expression .NewGetFieldWithTable (
124
+ index ,
125
+ parentField .Type (),
126
+ parentField .Table (),
127
+ parentField .Name (),
128
+ parentField .IsNullable (),
129
+ )
130
+
131
+ projections = append (projections , projection )
132
+ }
133
+
134
+ return plan .NewProject (projections , child .Child ), true
135
+ }
136
+
137
+ func referenceSameColumn (parent , child * expression.GetField ) bool {
138
+ return parent .Name () == child .Name () && parent .Table () == child .Table ()
54
139
}
55
140
56
141
func squashJoin (join * plan.InnerJoin ) (sql.Node , error ) {
0 commit comments