Skip to content

Commit 80f169a

Browse files
committed
Add map(filter()) optimization
1 parent 88a1913 commit 80f169a

File tree

8 files changed

+138
-4
lines changed

8 files changed

+138
-4
lines changed

ast/node.go

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ type BuiltinNode struct {
134134
Name string
135135
Arguments []Node
136136
Throws bool
137+
Map Node
137138
}
138139

139140
type ClosureNode struct {

compiler/compiler.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
688688
c.compile(node.Arguments[1])
689689
c.emitCond(func() {
690690
c.emit(OpIncrementCount)
691-
c.emit(OpPointer)
691+
if node.Map != nil {
692+
c.compile(node.Map)
693+
} else {
694+
c.emit(OpPointer)
695+
}
692696
})
693697
})
694698
c.emit(OpGetCount)
@@ -728,7 +732,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
728732
c.compile(node.Arguments[1])
729733
noop := c.emit(OpJumpIfFalse, placeholder)
730734
c.emit(OpPop)
731-
c.emit(OpPointer)
735+
if node.Map != nil {
736+
c.compile(node.Map)
737+
} else {
738+
c.emit(OpPointer)
739+
}
732740
loopBreak = c.emit(OpJump, placeholder)
733741
c.patchJump(noop)
734742
c.emit(OpPop)
@@ -769,7 +777,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
769777
c.compile(node.Arguments[1])
770778
noop := c.emit(OpJumpIfFalse, placeholder)
771779
c.emit(OpPop)
772-
c.emit(OpPointer)
780+
if node.Map != nil {
781+
c.compile(node.Map)
782+
} else {
783+
c.emit(OpPointer)
784+
}
773785
loopBreak = c.emit(OpJump, placeholder)
774786
c.patchJump(noop)
775787
c.emit(OpPop)

expr_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,34 @@ func TestExpr(t *testing.T) {
10001000
`last(filter(1..9, # % 2 == 0))`,
10011001
8,
10021002
},
1003+
{
1004+
`map(filter(1..9, # % 2 == 0), # * 2)`,
1005+
[]interface{}{4, 8, 12, 16},
1006+
},
1007+
{
1008+
`map(map(filter(1..9, # % 2 == 0), # * 2), # * 2)`,
1009+
[]interface{}{8, 16, 24, 32},
1010+
},
1011+
{
1012+
`first(map(filter(1..9, # % 2 == 0), # * 2))`,
1013+
4,
1014+
},
1015+
{
1016+
`map(filter(1..9, # % 2 == 0), # * 2)[-1]`,
1017+
16,
1018+
},
1019+
{
1020+
`len(map(filter(1..9, # % 2 == 0), # * 2))`,
1021+
4,
1022+
},
1023+
{
1024+
`len(filter(map(1..9, # * 2), # % 2 == 0))`,
1025+
9,
1026+
},
1027+
{
1028+
`first(filter(map(1..9, # * 2), # % 2 == 0))`,
1029+
2,
1030+
},
10031031
}
10041032

10051033
for _, tt := range tests {

optimizer/filter_first.go

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ func (*filterFirst) Visit(node *Node) {
1616
Name: "find",
1717
Arguments: filter.Arguments,
1818
Throws: true, // to match the behavior of filter()[0]
19+
Map: filter.Map,
1920
})
2021
}
2122
}
@@ -30,6 +31,7 @@ func (*filterFirst) Visit(node *Node) {
3031
Name: "find",
3132
Arguments: filter.Arguments,
3233
Throws: false, // as first() will return nil if not found
34+
Map: filter.Map,
3335
})
3436
}
3537
}

optimizer/filter_last.go

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ func (*filterLast) Visit(node *Node) {
1616
Name: "findLast",
1717
Arguments: filter.Arguments,
1818
Throws: true, // to match the behavior of filter()[-1]
19+
Map: filter.Map,
1920
})
2021
}
2122
}
@@ -30,6 +31,7 @@ func (*filterLast) Visit(node *Node) {
3031
Name: "findLast",
3132
Arguments: filter.Arguments,
3233
Throws: false, // as last() will return nil if not found
34+
Map: filter.Map,
3335
})
3436
}
3537
}

optimizer/filter_map.go

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/antonmedv/expr/ast"
5+
)
6+
7+
type filterMap struct{}
8+
9+
func (*filterMap) Visit(node *Node) {
10+
if mapBuiltin, ok := (*node).(*BuiltinNode); ok &&
11+
mapBuiltin.Name == "map" &&
12+
len(mapBuiltin.Arguments) == 2 {
13+
if closure, ok := mapBuiltin.Arguments[1].(*ClosureNode); ok {
14+
if filter, ok := mapBuiltin.Arguments[0].(*BuiltinNode); ok &&
15+
filter.Name == "filter" &&
16+
filter.Map == nil /* not already optimized */ {
17+
Patch(node, &BuiltinNode{
18+
Name: "filter",
19+
Arguments: filter.Arguments,
20+
Map: closure.Node,
21+
})
22+
}
23+
}
24+
}
25+
}

optimizer/optimizer.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ func Optimize(node *Node, config *conf.Config) error {
3333
}
3434
Walk(node, &inRange{})
3535
Walk(node, &constRange{})
36+
Walk(node, &filterMap{})
3637
Walk(node, &filterLen{})
37-
Walk(node, &filterFirst{})
3838
Walk(node, &filterLast{})
39+
Walk(node, &filterFirst{})
3940
return nil
4041
}

optimizer/optimizer_test.go

+63
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,66 @@ func TestOptimize_filter_last(t *testing.T) {
277277

278278
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
279279
}
280+
281+
func TestOptimize_filter_map(t *testing.T) {
282+
tree, err := parser.Parse(`map(filter(users, .Name == "Bob"), .Age)`)
283+
require.NoError(t, err)
284+
285+
err = optimizer.Optimize(&tree.Node, nil)
286+
require.NoError(t, err)
287+
288+
expected := &ast.BuiltinNode{
289+
Name: "filter",
290+
Arguments: []ast.Node{
291+
&ast.IdentifierNode{Value: "users"},
292+
&ast.ClosureNode{
293+
Node: &ast.BinaryNode{
294+
Operator: "==",
295+
Left: &ast.MemberNode{
296+
Node: &ast.PointerNode{},
297+
Property: &ast.StringNode{Value: "Name"},
298+
},
299+
Right: &ast.StringNode{Value: "Bob"},
300+
},
301+
},
302+
},
303+
Map: &ast.MemberNode{
304+
Node: &ast.PointerNode{},
305+
Property: &ast.StringNode{Value: "Age"},
306+
},
307+
}
308+
309+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
310+
}
311+
312+
func TestOptimize_filter_map_first(t *testing.T) {
313+
tree, err := parser.Parse(`first(map(filter(users, .Name == "Bob"), .Age))`)
314+
require.NoError(t, err)
315+
316+
err = optimizer.Optimize(&tree.Node, nil)
317+
require.NoError(t, err)
318+
319+
expected := &ast.BuiltinNode{
320+
Name: "find",
321+
Arguments: []ast.Node{
322+
&ast.IdentifierNode{Value: "users"},
323+
&ast.ClosureNode{
324+
Node: &ast.BinaryNode{
325+
Operator: "==",
326+
Left: &ast.MemberNode{
327+
Node: &ast.PointerNode{},
328+
Property: &ast.StringNode{Value: "Name"},
329+
},
330+
Right: &ast.StringNode{Value: "Bob"},
331+
},
332+
},
333+
},
334+
Map: &ast.MemberNode{
335+
Node: &ast.PointerNode{},
336+
Property: &ast.StringNode{Value: "Age"},
337+
},
338+
Throws: false,
339+
}
340+
341+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
342+
}

0 commit comments

Comments
 (0)