Skip to content

Commit a875bba

Browse files
committed
Add filter()[0] optimization
1 parent 8b8934f commit a875bba

File tree

9 files changed

+147
-2
lines changed

9 files changed

+147
-2
lines changed

ast/node.go

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ type BuiltinNode struct {
133133
base
134134
Name string
135135
Arguments []Node
136+
Throws bool
136137
}
137138

138139
type ClosureNode struct {

compiler/compiler.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,12 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
733733
c.patchJump(noop)
734734
c.emit(OpPop)
735735
})
736-
c.emit(OpNil)
736+
if node.Throws {
737+
c.emit(OpPush, c.addConstant(fmt.Errorf("reflect: slice index out of range")))
738+
c.emit(OpThrow)
739+
} else {
740+
c.emit(OpNil)
741+
}
737742
c.patchJump(loopBreak)
738743
c.emit(OpEnd)
739744
return
@@ -751,7 +756,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
751756
c.patchJump(noop)
752757
c.emit(OpPop)
753758
})
754-
c.emit(OpPushInt, -1)
759+
c.emit(OpNil)
755760
c.patchJump(loopBreak)
756761
c.emit(OpEnd)
757762
return

expr_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,18 @@ func TestExpr(t *testing.T) {
972972
`findIndex(ArrayOfFoo, .Value == "baz")`,
973973
2,
974974
},
975+
{
976+
`filter(ArrayOfFoo, .Value == "baz")[0]`,
977+
env.ArrayOfFoo[2],
978+
},
979+
{
980+
`first(filter(ArrayOfFoo, .Value == "baz"))`,
981+
env.ArrayOfFoo[2],
982+
},
983+
{
984+
`first(filter(ArrayOfFoo, false))`,
985+
nil,
986+
},
975987
}
976988

977989
for _, tt := range tests {
@@ -1003,6 +1015,33 @@ func TestExpr(t *testing.T) {
10031015
}
10041016
}
10051017

1018+
func TestExpr_error(t *testing.T) {
1019+
env := mock.Env{}
1020+
1021+
tests := []struct {
1022+
code string
1023+
want string
1024+
}{
1025+
{
1026+
`filter(1..9, # > 9)[0]`,
1027+
`reflect: slice index out of range (1:20)
1028+
| filter(1..9, # > 9)[0]
1029+
| ...................^`,
1030+
},
1031+
}
1032+
1033+
for _, tt := range tests {
1034+
t.Run(tt.code, func(t *testing.T) {
1035+
program, err := expr.Compile(tt.code, expr.Env(mock.Env{}))
1036+
require.NoError(t, err)
1037+
1038+
_, err = expr.Run(program, env)
1039+
require.Error(t, err)
1040+
assert.Equal(t, tt.want, err.Error())
1041+
})
1042+
}
1043+
}
1044+
10061045
func TestExpr_optional_chaining(t *testing.T) {
10071046
env := map[string]interface{}{}
10081047
program, err := expr.Compile("foo?.bar.baz", expr.Env(env), expr.AllowUndefinedVariables())

optimizer/filter_first.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/antonmedv/expr/ast"
5+
)
6+
7+
type filterFirst struct{}
8+
9+
func (*filterFirst) Visit(node *Node) {
10+
if member, ok := (*node).(*MemberNode); ok && member.Property != nil && !member.Optional {
11+
if prop, ok := member.Property.(*IntegerNode); ok && prop.Value == 0 {
12+
if filter, ok := member.Node.(*BuiltinNode); ok &&
13+
filter.Name == "filter" &&
14+
len(filter.Arguments) == 2 {
15+
Patch(node, &BuiltinNode{
16+
Name: "find",
17+
Arguments: filter.Arguments,
18+
Throws: true, // to match the behavior of filter()[0]
19+
})
20+
}
21+
}
22+
}
23+
if first, ok := (*node).(*BuiltinNode); ok &&
24+
first.Name == "first" &&
25+
len(first.Arguments) == 1 {
26+
if filter, ok := first.Arguments[0].(*BuiltinNode); ok &&
27+
filter.Name == "filter" &&
28+
len(filter.Arguments) == 2 {
29+
Patch(node, &BuiltinNode{
30+
Name: "find",
31+
Arguments: filter.Arguments,
32+
Throws: false, // as first() will return nil if not found
33+
})
34+
}
35+
}
36+
}

optimizer/optimizer.go

+1
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,6 @@ func Optimize(node *Node, config *conf.Config) error {
3434
Walk(node, &inRange{})
3535
Walk(node, &constRange{})
3636
Walk(node, &filterLen{})
37+
Walk(node, &filterFirst{})
3738
return nil
3839
}

optimizer/optimizer_test.go

+56
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,59 @@ func TestOptimize_filter_len(t *testing.T) {
165165

166166
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
167167
}
168+
169+
func TestOptimize_filter_0(t *testing.T) {
170+
tree, err := parser.Parse(`filter(users, .Name == "Bob")[0]`)
171+
require.NoError(t, err)
172+
173+
err = optimizer.Optimize(&tree.Node, nil)
174+
require.NoError(t, err)
175+
176+
expected := &ast.BuiltinNode{
177+
Name: "find",
178+
Arguments: []ast.Node{
179+
&ast.IdentifierNode{Value: "users"},
180+
&ast.ClosureNode{
181+
Node: &ast.BinaryNode{
182+
Operator: "==",
183+
Left: &ast.MemberNode{
184+
Node: &ast.PointerNode{},
185+
Property: &ast.StringNode{Value: "Name"},
186+
},
187+
Right: &ast.StringNode{Value: "Bob"},
188+
},
189+
},
190+
},
191+
Throws: true,
192+
}
193+
194+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
195+
}
196+
197+
func TestOptimize_filter_first(t *testing.T) {
198+
tree, err := parser.Parse(`first(filter(users, .Name == "Bob"))`)
199+
require.NoError(t, err)
200+
201+
err = optimizer.Optimize(&tree.Node, nil)
202+
require.NoError(t, err)
203+
204+
expected := &ast.BuiltinNode{
205+
Name: "find",
206+
Arguments: []ast.Node{
207+
&ast.IdentifierNode{Value: "users"},
208+
&ast.ClosureNode{
209+
Node: &ast.BinaryNode{
210+
Operator: "==",
211+
Left: &ast.MemberNode{
212+
Node: &ast.PointerNode{},
213+
Property: &ast.StringNode{Value: "Name"},
214+
},
215+
Right: &ast.StringNode{Value: "Bob"},
216+
},
217+
},
218+
},
219+
Throws: false,
220+
}
221+
222+
assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
223+
}

vm/opcodes.go

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ const (
7171
OpGetCount
7272
OpGetLen
7373
OpPointer
74+
OpThrow
7475
OpBegin
7576
OpEnd // This opcode must be at the end of this list.
7677
)

vm/program.go

+3
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ func (program *Program) Opcodes(w io.Writer) {
286286
case OpPointer:
287287
code("OpPointer")
288288

289+
case OpThrow:
290+
code("OpThrow")
291+
289292
case OpBegin:
290293
code("OpBegin")
291294

vm/vm.go

+3
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ func (vm *VM) Run(program *Program, env interface{}) (_ interface{}, err error)
457457
scope := vm.Scope()
458458
vm.push(scope.Array.Index(scope.Index).Interface())
459459

460+
case OpThrow:
461+
panic(vm.pop().(error))
462+
460463
case OpBegin:
461464
a := vm.pop()
462465
array := reflect.ValueOf(a)

0 commit comments

Comments
 (0)