Skip to content

Commit 56d152d

Browse files
committed
Updated
1 parent 74300b6 commit 56d152d

File tree

8 files changed

+63
-5
lines changed

8 files changed

+63
-5
lines changed

agent.go

+4
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ type Agent interface {
1111

1212
// Return the models
1313
Models(context.Context) ([]Model, error)
14+
15+
// Return a model by name, or nil if not found.
16+
// Panics on error.
17+
Model(context.Context, string) Model
1418
}

pkg/agent/agent.go

+9
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ func (a *Agent) Models(ctx context.Context) ([]llm.Model, error) {
105105
return a.ListModels(ctx)
106106
}
107107

108+
// Return a model
109+
func (a *Agent) Model(ctx context.Context, name string) llm.Model {
110+
model, err := a.GetModel(ctx, name)
111+
if err != nil {
112+
panic(err)
113+
}
114+
return model
115+
}
116+
108117
// Return the models from list of agents
109118
func (a *Agent) ListModels(ctx context.Context, names ...string) ([]llm.Model, error) {
110119
var result error

pkg/anthropic/client.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
type Client struct {
1616
*client.Client
17+
cache map[string]llm.Model
1718
}
1819

1920
var _ llm.Agent = (*Client)(nil)
@@ -41,7 +42,10 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
4142
}
4243

4344
// Return the client
44-
return &Client{client}, nil
45+
return &Client{
46+
Client: client,
47+
cache: make(map[string]llm.Model),
48+
}, nil
4549
}
4650

4751
///////////////////////////////////////////////////////////////////////////////

pkg/anthropic/model.go

+32-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,38 @@ type ModelMeta struct {
3434

3535
// Agent interface
3636
func (anthropic *Client) Models(ctx context.Context) ([]llm.Model, error) {
37-
return anthropic.ListModels(ctx)
37+
// Cache models
38+
if len(anthropic.cache) == 0 {
39+
models, err := anthropic.ListModels(ctx)
40+
if err != nil {
41+
return nil, err
42+
}
43+
for _, model := range models {
44+
name := model.Name()
45+
anthropic.cache[name] = model
46+
}
47+
}
48+
49+
// Return models
50+
result := make([]llm.Model, 0, len(anthropic.cache))
51+
for _, model := range anthropic.cache {
52+
result = append(result, model)
53+
}
54+
return result, nil
55+
}
56+
57+
// Agent interface
58+
func (anthropic *Client) Model(ctx context.Context, model string) llm.Model {
59+
// Cache models
60+
if len(anthropic.cache) == 0 {
61+
_, err := anthropic.Models(ctx)
62+
if err != nil {
63+
panic(err)
64+
}
65+
}
66+
67+
// Return model
68+
return anthropic.cache[model]
3869
}
3970

4071
// Get a model by name

pkg/anthropic/session_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ func Test_session_002(t *testing.T) {
8383
t.FailNow()
8484
}
8585

86-
err := toolkit.Run(context.TODO(), session.ToolCalls()...)
86+
result, err := toolkit.Run(context.TODO(), session.ToolCalls()...)
8787
if !assert.NoError(err) {
8888
t.FailNow()
8989
}
90+
assert.NotEmpty(result)
9091
})
9192
}

pkg/ollama/chat_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func Test_chat_001(t *testing.T) {
3131

3232
t.Run("ChatStream", func(t *testing.T) {
3333
assert := assert.New(t)
34-
response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.Context) {
34+
response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.ContextContent) {
3535
t.Log(stream)
3636
}))
3737
if !assert.NoError(err) {

pkg/ollama/model.go

+9
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,15 @@ func (ollama *Client) Models(ctx context.Context) ([]llm.Model, error) {
9191
return ollama.ListModels(ctx)
9292
}
9393

94+
// Agent interface
95+
func (ollama *Client) Model(ctx context.Context, name string) llm.Model {
96+
model, err := ollama.GetModel(ctx, name)
97+
if err != nil {
98+
panic(err)
99+
}
100+
return model
101+
}
102+
94103
// List models
95104
func (ollama *Client) ListModels(ctx context.Context) ([]llm.Model, error) {
96105
type respListModel struct {

pkg/ollama/session_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func Test_session_001(t *testing.T) {
2828
// Session with a single user prompt - streaming
2929
t.Run("stream", func(t *testing.T) {
3030
assert := assert.New(t)
31-
session := model.Context(llm.WithStream(func(stream llm.Context) {
31+
session := model.Context(llm.WithStream(func(stream llm.ContextContent) {
3232
t.Log("SESSION DELTA", stream)
3333
}))
3434
assert.NotNil(session)

0 commit comments

Comments
 (0)