Skip to content

Commit 9fc62bf

Browse files
committed
feat: add a call stack data structure
Signed-off-by: Donnie Adams <[email protected]>
1 parent 4b587e2 commit 9fc62bf

File tree

3 files changed

+114
-1
lines changed

3 files changed

+114
-1
lines changed

callstack.go

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package gptscript
2+
3+
type CallStack struct {
4+
Call CallFrame `json:"call"`
5+
SubCalls []CallStack `json:"subCalls"`
6+
}
7+
8+
func newCallStack(run *Run, parentID string) []CallStack {
9+
if len(run.callsByParentIDs[parentID]) == 0 {
10+
return nil
11+
}
12+
13+
var callStack []CallStack
14+
callIDs := run.callsByParentIDs[parentID]
15+
for len(callIDs) != 0 {
16+
callID := callIDs[0]
17+
callIDs = callIDs[1:]
18+
19+
cs := CallStack{
20+
Call: run.calls[callID],
21+
SubCalls: newCallStack(run, callID),
22+
}
23+
24+
callStack = append(callStack, cs)
25+
}
26+
27+
return callStack
28+
}

client_test.go

+54-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ func TestEvaluateWithToolList(t *testing.T) {
188188
},
189189
&ToolDef{
190190
Name: "echo",
191-
Tools: []string{"sys.exec"},
192191
Description: "Echoes the input",
193192
Args: map[string]string{
194193
"input": "The string input to echo",
@@ -212,6 +211,60 @@ func TestEvaluateWithToolList(t *testing.T) {
212211
}
213212
}
214213

214+
func TestCallStack(t *testing.T) {
215+
shebang := "#!/bin/bash"
216+
if runtime.GOOS == "windows" {
217+
shebang = "#!/usr/bin/env powershell.exe"
218+
}
219+
tools := []fmt.Stringer{
220+
&ToolDef{
221+
Tools: []string{"echo"},
222+
Instructions: "echo hello there",
223+
},
224+
&ToolDef{
225+
Name: "echo",
226+
Description: "Echoes the input",
227+
Args: map[string]string{
228+
"input": "The string input to echo",
229+
},
230+
Instructions: shebang + "\n echo ${input}",
231+
},
232+
}
233+
234+
run, err := c.Evaluate(context.Background(), Options{}, tools...)
235+
if err != nil {
236+
t.Errorf("Error executing tool: %v", err)
237+
}
238+
239+
// Wait for the run to complete.
240+
_, err = run.Text()
241+
if err != nil {
242+
t.Fatalf("Error waiting for run to finish: %v", err)
243+
}
244+
245+
cs := run.CallStack()
246+
247+
if cs == nil {
248+
t.Fatalf("No call stack")
249+
}
250+
251+
if cs.Call.ID == "" {
252+
t.Errorf("Call ID is empty")
253+
}
254+
255+
if len(cs.SubCalls) != 1 {
256+
t.Fatalf("Call stack contains %d subcalls", len(cs.SubCalls))
257+
}
258+
259+
if cs.SubCalls[0].Call.ParentID != cs.Call.ID {
260+
t.Errorf("Unexpected call stack sub call parent: %s != %s", cs.SubCalls[0].Call.ParentID, cs.Call.ID)
261+
}
262+
263+
if cs.SubCalls[0].Call.Tool.Name != "echo" {
264+
t.Errorf("Unexpected tool name: %s", cs.SubCalls[0].Call.Tool.Name)
265+
}
266+
}
267+
215268
func TestEvaluateWithToolListAndSubTool(t *testing.T) {
216269
shebang := "#!/bin/bash"
217270
if runtime.GOOS == "windows" {

run.go

+32
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"maps"
1212
"net/http"
1313
"os/exec"
14+
"slices"
1415
"strconv"
1516
"sync"
1617
)
@@ -29,6 +30,7 @@ type Run struct {
2930

3031
callsLock sync.RWMutex
3132
calls map[string]CallFrame
33+
callsByParentIDs map[string][]string
3234
parentCallFrameID string
3335
rawOutput map[string]any
3436
output, errput string
@@ -82,6 +84,19 @@ func (r *Run) ParentCallFrame() (CallFrame, bool) {
8284
return r.calls[r.parentCallFrameID], true
8385
}
8486

87+
// CallStack returns a nested version of the current call stack of the run.
88+
func (r *Run) CallStack() *CallStack {
89+
r.callsLock.RLock()
90+
defer r.callsLock.RUnlock()
91+
92+
cs := newCallStack(r, "")
93+
if len(cs) == 0 {
94+
return nil
95+
}
96+
97+
return &cs[0]
98+
}
99+
85100
// ErrorOutput returns the stderr output of the gptscript.
86101
// Should only be called after Bytes or Text has returned an error.
87102
func (r *Run) ErrorOutput() string {
@@ -166,6 +181,10 @@ func (r *Run) NextChat(ctx context.Context, input string) (*Run, error) {
166181
}
167182

168183
func (r *Run) request(ctx context.Context, payload any) (err error) {
184+
if r.state.IsTerminal() {
185+
return fmt.Errorf("run is in terminal state and cannot be run again: state %q", r.state)
186+
}
187+
169188
var (
170189
req *http.Request
171190
url = fmt.Sprintf("%s/%s", r.url, r.requestPath)
@@ -228,6 +247,11 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
228247
r.lock.Unlock()
229248
}()
230249

250+
r.callsLock.Lock()
251+
r.calls = make(map[string]CallFrame)
252+
r.callsByParentIDs = make(map[string][]string)
253+
r.callsLock.Unlock()
254+
231255
for n := 0; n != 0 || err == nil; n, err = resp.Body.Read(buf) {
232256
for _, line := range bytes.Split(bytes.TrimSpace(append(frag, buf[:n]...)), []byte("\n\n")) {
233257
line = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
@@ -316,6 +340,8 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
316340
if r.parentCallFrameID == "" && event.Call.ParentID == "" {
317341
r.parentCallFrameID = event.Call.ID
318342
}
343+
344+
r.appendCallIDToParentID(event.Call.ID, event.Call.ParentID)
319345
r.callsLock.Unlock()
320346
}
321347

@@ -353,6 +379,12 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
353379
return nil
354380
}
355381

382+
func (r *Run) appendCallIDToParentID(callID, parentID string) {
383+
if !slices.Contains(r.callsByParentIDs[parentID], callID) {
384+
r.callsByParentIDs[parentID] = append(r.callsByParentIDs[parentID], callID)
385+
}
386+
}
387+
356388
type RunState string
357389

358390
func (rs RunState) IsTerminal() bool {

0 commit comments

Comments
 (0)