@@ -11,6 +11,7 @@ import (
11
11
"maps"
12
12
"net/http"
13
13
"os/exec"
14
+ "slices"
14
15
"strconv"
15
16
"sync"
16
17
)
@@ -29,6 +30,7 @@ type Run struct {
29
30
30
31
callsLock sync.RWMutex
31
32
calls map [string ]CallFrame
33
+ callsByParentIDs map [string ][]string
32
34
parentCallFrameID string
33
35
rawOutput map [string ]any
34
36
output , errput string
@@ -82,6 +84,19 @@ func (r *Run) ParentCallFrame() (CallFrame, bool) {
82
84
return r .calls [r .parentCallFrameID ], true
83
85
}
84
86
87
+ // CallStack returns a nested version of the current call stack of the run.
88
+ func (r * Run ) CallStack () * CallStack {
89
+ r .callsLock .RLock ()
90
+ defer r .callsLock .RUnlock ()
91
+
92
+ cs := newCallStack (r , "" )
93
+ if len (cs ) == 0 {
94
+ return nil
95
+ }
96
+
97
+ return & cs [0 ]
98
+ }
99
+
85
100
// ErrorOutput returns the stderr output of the gptscript.
86
101
// Should only be called after Bytes or Text has returned an error.
87
102
func (r * Run ) ErrorOutput () string {
@@ -166,6 +181,10 @@ func (r *Run) NextChat(ctx context.Context, input string) (*Run, error) {
166
181
}
167
182
168
183
func (r * Run ) request (ctx context.Context , payload any ) (err error ) {
184
+ if r .state .IsTerminal () {
185
+ return fmt .Errorf ("run is in terminal state and cannot be run again: state %q" , r .state )
186
+ }
187
+
169
188
var (
170
189
req * http.Request
171
190
url = fmt .Sprintf ("%s/%s" , r .url , r .requestPath )
@@ -228,6 +247,11 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
228
247
r .lock .Unlock ()
229
248
}()
230
249
250
+ r .callsLock .Lock ()
251
+ r .calls = make (map [string ]CallFrame )
252
+ r .callsByParentIDs = make (map [string ][]string )
253
+ r .callsLock .Unlock ()
254
+
231
255
for n := 0 ; n != 0 || err == nil ; n , err = resp .Body .Read (buf ) {
232
256
for _ , line := range bytes .Split (bytes .TrimSpace (append (frag , buf [:n ]... )), []byte ("\n \n " )) {
233
257
line = bytes .TrimSpace (bytes .TrimPrefix (line , []byte ("data: " )))
@@ -316,6 +340,8 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
316
340
if r .parentCallFrameID == "" && event .Call .ParentID == "" {
317
341
r .parentCallFrameID = event .Call .ID
318
342
}
343
+
344
+ r .appendCallIDToParentID (event .Call .ID , event .Call .ParentID )
319
345
r .callsLock .Unlock ()
320
346
}
321
347
@@ -353,6 +379,12 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
353
379
return nil
354
380
}
355
381
382
+ func (r * Run ) appendCallIDToParentID (callID , parentID string ) {
383
+ if ! slices .Contains (r .callsByParentIDs [parentID ], callID ) {
384
+ r .callsByParentIDs [parentID ] = append (r .callsByParentIDs [parentID ], callID )
385
+ }
386
+ }
387
+
356
388
type RunState string
357
389
358
390
func (rs RunState ) IsTerminal () bool {
0 commit comments