Skip to content

Commit ed3d1f3

Browse files
committed
Add reduce()
1 parent 82bb7df commit ed3d1f3

File tree

10 files changed

+112
-3
lines changed

10 files changed

+112
-3
lines changed

bench_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,17 @@ func Benchmark_countBy(b *testing.B) {
526526

527527
require.Equal(b, 14, out.(int))
528528
}
529+
530+
func Benchmark_reduce(b *testing.B) {
531+
program, err := expr.Compile(`reduce(1..100, # + #acc)`)
532+
require.NoError(b, err)
533+
534+
var out any
535+
b.ResetTimer()
536+
for n := 0; n < b.N; n++ {
537+
out, _ = vm.Run(program, nil)
538+
}
539+
b.StopTimer()
540+
541+
require.Equal(b, 5050, out.(int))
542+
}

builtin/builtin.go

+5
Original file line numberDiff line numberDiff line change
@@ -800,4 +800,9 @@ var Builtins = []*ast.Function{
800800
Predicate: true,
801801
Types: types(new(func([]any, func(any) any) map[any]int)),
802802
},
803+
{
804+
Name: "reduce",
805+
Predicate: true,
806+
Types: types(new(func([]any, func(any, any) any, any) any)),
807+
},
803808
}

builtin/builtin_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ func TestBuiltin(t *testing.T) {
106106
{`groupBy(1..3, # > 1 ? nil : "")[nil]`, []any{2, 3}},
107107
{`groupBy(ArrayOfFoo, .Value).a`, []any{mock.Foo{Value: "a"}}},
108108
{`countBy(1..9, # % 2)`, map[any]int{0: 4, 1: 5}},
109+
{`reduce(1..9, # + #acc, 0)`, 45},
110+
{`reduce(1..9, # + #acc)`, 45},
111+
{`reduce([.5, 1.5, 2.5], # + #acc, 0)`, 4.5},
112+
{`reduce([], 5, 0)`, 0},
109113
}
110114

111115
for _, test := range tests {

checker/checker.go

+19
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,25 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
755755
}
756756
return v.error(node.Arguments[1], "predicate should has one input and one output param")
757757

758+
case "reduce":
759+
collection, _ := v.visit(node.Arguments[0])
760+
if !isArray(collection) && !isAny(collection) {
761+
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
762+
}
763+
764+
v.begin(collection, scopeVar{"index", integerType}, scopeVar{"acc", anyType})
765+
closure, _ := v.visit(node.Arguments[1])
766+
v.end()
767+
768+
if len(node.Arguments) == 3 {
769+
_, _ = v.visit(node.Arguments[2])
770+
}
771+
772+
if isFunc(closure) && closure.NumOut() == 1 {
773+
return closure.Out(0), info{}
774+
}
775+
return v.error(node.Arguments[1], "predicate should has two input and one output param")
776+
758777
}
759778

760779
if id, ok := builtin.Index[node.Name]; ok {

compiler/compiler.go

+21
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
836836
c.emit(OpEnd)
837837
return
838838

839+
case "reduce":
840+
c.compile(node.Arguments[0])
841+
c.emit(OpBegin)
842+
if len(node.Arguments) == 3 {
843+
c.compile(node.Arguments[2])
844+
c.emit(OpSetAcc)
845+
} else {
846+
c.emit(OpPointer)
847+
c.emit(OpIncrementIndex)
848+
c.emit(OpSetAcc)
849+
}
850+
c.emitLoop(func() {
851+
c.compile(node.Arguments[1])
852+
c.emit(OpSetAcc)
853+
})
854+
c.emit(OpGetAcc)
855+
c.emit(OpEnd)
856+
return
857+
839858
}
840859

841860
if id, ok := builtin.Index[node.Name]; ok {
@@ -903,6 +922,8 @@ func (c *compiler) PointerNode(node *ast.PointerNode) {
903922
switch node.Name {
904923
case "index":
905924
c.emit(OpGetIndex)
925+
case "acc":
926+
c.emit(OpGetAcc)
906927
case "":
907928
c.emit(OpPointer)
908929
default:

debug/debugger.go

+3
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ func StartDebugger(program *Program, env any) {
133133
if s.GroupBy != nil {
134134
keys = append(keys, pair{"GroupBy", s.GroupBy})
135135
}
136+
if s.Acc != nil {
137+
keys = append(keys, pair{"Acc", s.Acc})
138+
}
136139
row := 0
137140
for _, pair := range keys {
138141
scope.SetCellSimple(row, 0, fmt.Sprintf("%v: ", pair.key))

parser/parser.go

+28
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ var predicates = map[string]struct {
3030
"findLastIndex": {2},
3131
"groupBy": {2},
3232
"countBy": {2},
33+
"reduce": {3},
3334
}
3435

3536
type parser struct {
@@ -357,6 +358,9 @@ func (p *parser) parseCall(token Token) Node {
357358

358359
if b, ok := predicates[token.Value]; ok {
359360
p.expect(Bracket, "(")
361+
362+
// TODO: Refactor parser to use builtin.Builtins instead of predicates map.
363+
360364
if b.arity == 1 {
361365
arguments = make([]Node, 1)
362366
arguments[0] = p.parseExpression(0)
@@ -366,6 +370,18 @@ func (p *parser) parseCall(token Token) Node {
366370
p.expect(Operator, ",")
367371
arguments[1] = p.parseClosure()
368372
}
373+
374+
if token.Value == "reduce" {
375+
arguments = make([]Node, 2)
376+
arguments[0] = p.parseExpression(0)
377+
p.expect(Operator, ",")
378+
arguments[1] = p.parseClosure()
379+
if p.current.Is(Operator, ",") {
380+
p.next()
381+
arguments = append(arguments, p.parseExpression(0))
382+
}
383+
}
384+
369385
p.expect(Bracket, ")")
370386

371387
node = &BuiltinNode{
@@ -596,9 +612,21 @@ func (p *parser) parsePipe(node Node) Node {
596612

597613
if b, ok := predicates[identifier.Value]; ok {
598614
p.expect(Bracket, "(")
615+
616+
// TODO: Refactor parser to use builtin.Builtins instead of predicates map.
617+
599618
if b.arity == 2 {
600619
arguments = append(arguments, p.parseClosure())
601620
}
621+
622+
if identifier.Value == "reduce" {
623+
arguments = append(arguments, p.parseClosure())
624+
if p.current.Is(Operator, ",") {
625+
p.next()
626+
arguments = append(arguments, p.parseExpression(0))
627+
}
628+
}
629+
602630
p.expect(Bracket, ")")
603631

604632
node = &BuiltinNode{

vm/opcodes.go

+2
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@ const (
7474
OpGetLen
7575
OpGetGroupBy
7676
OpGetCountBy
77+
OpGetAcc
7778
OpPointer
7879
OpThrow
7980
OpGroupBy
8081
OpCountBy
82+
OpSetAcc
8183
OpBegin
8284
OpEnd // This opcode must be at the end of this list.
8385
)

vm/program.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -295,21 +295,27 @@ func (program *Program) Opcodes(w io.Writer) {
295295
case OpGetCountBy:
296296
code("OpGetCountBy")
297297

298+
case OpGetAcc:
299+
code("OpGetAcc")
300+
298301
case OpPointer:
299302
code("OpPointer")
300303

301304
case OpThrow:
302305
code("OpThrow")
303306

304-
case OpBegin:
305-
code("OpBegin")
306-
307307
case OpGroupBy:
308308
code("OpGroupBy")
309309

310310
case OpCountBy:
311311
code("OpCountBy")
312312

313+
case OpSetAcc:
314+
code("OpSetAcc")
315+
316+
case OpBegin:
317+
code("OpBegin")
318+
313319
case OpEnd:
314320
code("OpEnd")
315321

vm/vm.go

+7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type Scope struct {
4545
Count int
4646
GroupBy map[any][]any
4747
CountBy map[any]int
48+
Acc any
4849
}
4950

5051
func Debug() *VM {
@@ -469,6 +470,12 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
469470
case OpGetCountBy:
470471
vm.push(vm.Scope().CountBy)
471472

473+
case OpGetAcc:
474+
vm.push(vm.Scope().Acc)
475+
476+
case OpSetAcc:
477+
vm.Scope().Acc = vm.pop()
478+
472479
case OpPointer:
473480
scope := vm.Scope()
474481
vm.push(scope.Array.Index(scope.Index).Interface())

0 commit comments

Comments
 (0)