Skip to content

Commit 88a1913

Browse files
committed
Add filter()[-1] optimization
1 parent a875bba commit 88a1913

File tree

13 files changed

+250
-59
lines changed

13 files changed

+250
-59
lines changed

builtin/builtin.go

+13-3
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ var Builtins = []*Function{
6060
Predicate: true,
6161
Types: types(new(func([]interface{}, func(interface{}) bool) []interface{})),
6262
},
63+
{
64+
Name: "map",
65+
Predicate: true,
66+
Types: types(new(func([]interface{}, func(interface{}) interface{}) []interface{})),
67+
},
68+
{
69+
Name: "count",
70+
Predicate: true,
71+
Types: types(new(func([]interface{}, func(interface{}) bool) int)),
72+
},
6373
{
6474
Name: "find",
6575
Predicate: true,
@@ -71,12 +81,12 @@ var Builtins = []*Function{
7181
Types: types(new(func([]interface{}, func(interface{}) bool) int)),
7282
},
7383
{
74-
Name: "map",
84+
Name: "findLast",
7585
Predicate: true,
76-
Types: types(new(func([]interface{}, func(interface{}) interface{}) []interface{})),
86+
Types: types(new(func([]interface{}, func(interface{}) bool) interface{})),
7787
},
7888
{
79-
Name: "count",
89+
Name: "findLastIndex",
8090
Predicate: true,
8191
Types: types(new(func([]interface{}, func(interface{}) bool) int)),
8292
},

checker/checker.go

+13-14
Original file line numberDiff line numberDiff line change
@@ -604,11 +604,11 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
604604
closure.NumIn() == 1 && isAny(closure.In(0)) {
605605

606606
if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
607-
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
607+
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
608608
}
609609
return boolType, info{}
610610
}
611-
return v.error(node.Arguments[1], "closure should has one input and one output param")
611+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
612612

613613
case "filter":
614614
collection, _ := v.visit(node.Arguments[0])
@@ -625,14 +625,14 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
625625
closure.NumIn() == 1 && isAny(closure.In(0)) {
626626

627627
if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
628-
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
628+
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
629629
}
630630
if isAny(collection) {
631631
return arrayType, info{}
632632
}
633633
return reflect.SliceOf(collection.Elem()), info{}
634634
}
635-
return v.error(node.Arguments[1], "closure should has one input and one output param")
635+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
636636

637637
case "map":
638638
collection, _ := v.visit(node.Arguments[0])
@@ -650,7 +650,7 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
650650

651651
return reflect.SliceOf(closure.Out(0)), info{}
652652
}
653-
return v.error(node.Arguments[1], "closure should has one input and one output param")
653+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
654654

655655
case "count":
656656
collection, _ := v.visit(node.Arguments[0])
@@ -666,14 +666,14 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
666666
closure.NumOut() == 1 &&
667667
closure.NumIn() == 1 && isAny(closure.In(0)) {
668668
if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
669-
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
669+
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
670670
}
671671

672672
return integerType, info{}
673673
}
674-
return v.error(node.Arguments[1], "closure should has one input and one output param")
674+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
675675

676-
case "find":
676+
case "find", "findLast":
677677
collection, _ := v.visit(node.Arguments[0])
678678
if !isArray(collection) && !isAny(collection) {
679679
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
@@ -688,16 +688,16 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
688688
closure.NumIn() == 1 && isAny(closure.In(0)) {
689689

690690
if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
691-
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
691+
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
692692
}
693693
if isAny(collection) {
694694
return anyType, info{}
695695
}
696696
return collection.Elem(), info{}
697697
}
698-
return v.error(node.Arguments[1], "closure should has one input and one output param")
698+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
699699

700-
case "findIndex":
700+
case "findIndex", "findLastIndex":
701701
collection, _ := v.visit(node.Arguments[0])
702702
if !isArray(collection) && !isAny(collection) {
703703
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
@@ -712,12 +712,11 @@ func (v *visitor) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
712712
closure.NumIn() == 1 && isAny(closure.In(0)) {
713713

714714
if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) {
715-
return v.error(node.Arguments[1], "closure should return boolean (got %v)", closure.Out(0).String())
715+
return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String())
716716
}
717717
return integerType, info{}
718718
}
719-
return v.error(node.Arguments[1], "closure should has one input and one output param")
720-
719+
return v.error(node.Arguments[1], "predicate should has one input and one output param")
721720
}
722721

723722
if id, ok := builtin.Index[node.Name]; ok {

checker/checker_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -414,22 +414,22 @@ builtin count takes only array (got int) (1:7)
414414
| ......^
415415
416416
count(ArrayOfInt, {#})
417-
closure should return boolean (got int) (1:19)
417+
predicate should return boolean (got int) (1:19)
418418
| count(ArrayOfInt, {#})
419419
| ..................^
420420
421421
all(ArrayOfInt, {# + 1})
422-
closure should return boolean (got int) (1:17)
422+
predicate should return boolean (got int) (1:17)
423423
| all(ArrayOfInt, {# + 1})
424424
| ................^
425425
426426
filter(ArrayOfFoo, {.Bar.Baz})
427-
closure should return boolean (got string) (1:20)
427+
predicate should return boolean (got string) (1:20)
428428
| filter(ArrayOfFoo, {.Bar.Baz})
429429
| ...................^
430430
431431
find(ArrayOfFoo, {.Bar.Baz})
432-
closure should return boolean (got string) (1:18)
432+
predicate should return boolean (got string) (1:18)
433433
| find(ArrayOfFoo, {.Bar.Baz})
434434
| .................^
435435

compiler/compiler.go

+60
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,48 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
760760
c.patchJump(loopBreak)
761761
c.emit(OpEnd)
762762
return
763+
764+
case "findLast":
765+
c.compile(node.Arguments[0])
766+
c.emit(OpBegin)
767+
var loopBreak int
768+
c.emitLoopBackwards(func() {
769+
c.compile(node.Arguments[1])
770+
noop := c.emit(OpJumpIfFalse, placeholder)
771+
c.emit(OpPop)
772+
c.emit(OpPointer)
773+
loopBreak = c.emit(OpJump, placeholder)
774+
c.patchJump(noop)
775+
c.emit(OpPop)
776+
})
777+
if node.Throws {
778+
c.emit(OpPush, c.addConstant(fmt.Errorf("reflect: slice index out of range")))
779+
c.emit(OpThrow)
780+
} else {
781+
c.emit(OpNil)
782+
}
783+
c.patchJump(loopBreak)
784+
c.emit(OpEnd)
785+
return
786+
787+
case "findLastIndex":
788+
c.compile(node.Arguments[0])
789+
c.emit(OpBegin)
790+
var loopBreak int
791+
c.emitLoopBackwards(func() {
792+
c.compile(node.Arguments[1])
793+
noop := c.emit(OpJumpIfFalse, placeholder)
794+
c.emit(OpPop)
795+
c.emit(OpGetIndex)
796+
loopBreak = c.emit(OpJump, placeholder)
797+
c.patchJump(noop)
798+
c.emit(OpPop)
799+
})
800+
c.emit(OpNil)
801+
c.patchJump(loopBreak)
802+
c.emit(OpEnd)
803+
return
804+
763805
}
764806

765807
if id, ok := builtin.Index[node.Name]; ok {
@@ -801,6 +843,24 @@ func (c *compiler) emitLoop(body func()) {
801843
c.patchJump(end)
802844
}
803845

846+
func (c *compiler) emitLoopBackwards(body func()) {
847+
c.emit(OpGetLen)
848+
c.emit(OpInt, 1)
849+
c.emit(OpSubtract)
850+
c.emit(OpSetIndex)
851+
begin := len(c.bytecode)
852+
c.emit(OpGetIndex)
853+
c.emit(OpInt, 0)
854+
c.emit(OpMoreOrEqual)
855+
end := c.emit(OpJumpIfFalse, placeholder)
856+
857+
body()
858+
859+
c.emit(OpDecrementIndex)
860+
c.emit(OpJumpBackward, c.calcBackwardJump(begin))
861+
c.patchJump(end)
862+
}
863+
804864
func (c *compiler) ClosureNode(node *ast.ClosureNode) {
805865
c.compile(node.Node)
806866
}

expr_test.go

+35-21
Original file line numberDiff line numberDiff line change
@@ -984,33 +984,47 @@ func TestExpr(t *testing.T) {
984984
`first(filter(ArrayOfFoo, false))`,
985985
nil,
986986
},
987+
{
988+
`findLast(1..9, # % 2 == 0)`,
989+
8,
990+
},
991+
{
992+
`findLastIndex(1..9, # % 2 == 0)`,
993+
7,
994+
},
995+
{
996+
`filter(1..9, # % 2 == 0)[-1]`,
997+
8,
998+
},
999+
{
1000+
`last(filter(1..9, # % 2 == 0))`,
1001+
8,
1002+
},
9871003
}
9881004

9891005
for _, tt := range tests {
9901006
t.Run(tt.code, func(t *testing.T) {
991-
program, err := expr.Compile(tt.code, expr.Env(mock.Env{}))
992-
require.NoError(t, err, "compile error")
1007+
{
1008+
program, err := expr.Compile(tt.code, expr.Env(mock.Env{}))
1009+
require.NoError(t, err, "compile error")
9931010

994-
got, err := expr.Run(program, env)
995-
require.NoError(t, err, "run error")
996-
assert.Equal(t, tt.want, got)
997-
})
998-
}
999-
for _, tt := range tests {
1000-
t.Run("Unoptimized "+tt.code, func(t *testing.T) {
1001-
program, err := expr.Compile(tt.code, expr.Optimize(false))
1002-
require.NoError(t, err, "unoptimized")
1011+
got, err := expr.Run(program, env)
1012+
require.NoError(t, err, "run error")
1013+
assert.Equal(t, tt.want, got)
1014+
}
1015+
{
1016+
program, err := expr.Compile(tt.code, expr.Optimize(false))
1017+
require.NoError(t, err, "unoptimized")
10031018

1004-
got, err := expr.Run(program, env)
1005-
require.NoError(t, err, "unoptimized")
1006-
assert.Equal(t, tt.want, got, "unoptimized")
1007-
})
1008-
}
1009-
for _, tt := range tests {
1010-
t.Run("Eval "+tt.code, func(t *testing.T) {
1011-
got, err := expr.Eval(tt.code, env)
1012-
require.NoError(t, err, "eval")
1013-
assert.Equal(t, tt.want, got, "eval")
1019+
got, err := expr.Run(program, env)
1020+
require.NoError(t, err, "unoptimized")
1021+
assert.Equal(t, tt.want, got, "unoptimized")
1022+
}
1023+
{
1024+
got, err := expr.Eval(tt.code, env)
1025+
require.NoError(t, err, "eval")
1026+
assert.Equal(t, tt.want, got, "eval")
1027+
}
10141028
})
10151029
}
10161030
}

optimizer/filter_last.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/antonmedv/expr/ast"
5+
)
6+
7+
type filterLast struct{}
8+
9+
func (*filterLast) Visit(node *Node) {
10+
if member, ok := (*node).(*MemberNode); ok && member.Property != nil && !member.Optional {
11+
if prop, ok := member.Property.(*IntegerNode); ok && prop.Value == -1 {
12+
if filter, ok := member.Node.(*BuiltinNode); ok &&
13+
filter.Name == "filter" &&
14+
len(filter.Arguments) == 2 {
15+
Patch(node, &BuiltinNode{
16+
Name: "findLast",
17+
Arguments: filter.Arguments,
18+
Throws: true, // to match the behavior of filter()[-1]
19+
})
20+
}
21+
}
22+
}
23+
if first, ok := (*node).(*BuiltinNode); ok &&
24+
first.Name == "last" &&
25+
len(first.Arguments) == 1 {
26+
if filter, ok := first.Arguments[0].(*BuiltinNode); ok &&
27+
filter.Name == "filter" &&
28+
len(filter.Arguments) == 2 {
29+
Patch(node, &BuiltinNode{
30+
Name: "findLast",
31+
Arguments: filter.Arguments,
32+
Throws: false, // as last() will return nil if not found
33+
})
34+
}
35+
}
36+
}

optimizer/optimizer.go

+1
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,6 @@ func Optimize(node *Node, config *conf.Config) error {
3535
Walk(node, &constRange{})
3636
Walk(node, &filterLen{})
3737
Walk(node, &filterFirst{})
38+
Walk(node, &filterLast{})
3839
return nil
3940
}

optimizer/optimizer_test.go

+56
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,59 @@ func TestOptimize_filter_first(t *testing.T) {
221221

222222
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
223223
}
224+
225+
func TestOptimize_filter_minus_1(t *testing.T) {
226+
tree, err := parser.Parse(`filter(users, .Name == "Bob")[-1]`)
227+
require.NoError(t, err)
228+
229+
err = optimizer.Optimize(&tree.Node, nil)
230+
require.NoError(t, err)
231+
232+
expected := &ast.BuiltinNode{
233+
Name: "findLast",
234+
Arguments: []ast.Node{
235+
&ast.IdentifierNode{Value: "users"},
236+
&ast.ClosureNode{
237+
Node: &ast.BinaryNode{
238+
Operator: "==",
239+
Left: &ast.MemberNode{
240+
Node: &ast.PointerNode{},
241+
Property: &ast.StringNode{Value: "Name"},
242+
},
243+
Right: &ast.StringNode{Value: "Bob"},
244+
},
245+
},
246+
},
247+
Throws: true,
248+
}
249+
250+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
251+
}
252+
253+
func TestOptimize_filter_last(t *testing.T) {
254+
tree, err := parser.Parse(`last(filter(users, .Name == "Bob"))`)
255+
require.NoError(t, err)
256+
257+
err = optimizer.Optimize(&tree.Node, nil)
258+
require.NoError(t, err)
259+
260+
expected := &ast.BuiltinNode{
261+
Name: "findLast",
262+
Arguments: []ast.Node{
263+
&ast.IdentifierNode{Value: "users"},
264+
&ast.ClosureNode{
265+
Node: &ast.BinaryNode{
266+
Operator: "==",
267+
Left: &ast.MemberNode{
268+
Node: &ast.PointerNode{},
269+
Property: &ast.StringNode{Value: "Name"},
270+
},
271+
Right: &ast.StringNode{Value: "Bob"},
272+
},
273+
},
274+
},
275+
Throws: false,
276+
}
277+
278+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
279+
}

0 commit comments

Comments
 (0)