Skip to content

Commit 857bb0d

Browse files
committed
Add dependency injection to ChatCompletionStream for improved testability
**Describe the change** This PR refactors the `ChatCompletionStream` to use dependency injection by introducing a `ChatStreamReader` interface. This allows for injecting custom stream readers, primarily for testing purposes, making the streaming functionality more testable and maintainable. **Provide OpenAI documentation link** https://platform.openai.com/docs/api-reference/chat/create **Describe your solution** The changes include: - Added a `ChatStreamReader` interface that defines the contract for reading chat completion streams - Refactored `ChatCompletionStream` to use composition with a `ChatStreamReader` instead of embedding `streamReader` - Added `NewChatCompletionStream()` constructor function to enable dependency injection - Implemented explicit delegation methods (`Recv()`, `Close()`, `Header()`, `GetRateLimitHeaders()`) on `ChatCompletionStream` - Added interface compliance check via `var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil)` This approach maintains backward compatibility while enabling easier mocking and testing of streaming functionality. **Tests** Added comprehensive tests demonstrating the new functionality: - `TestChatCompletionStream_MockInjection`: Tests basic mock injection with the new constructor - `mock_streaming_demo_test.go`: A complete demonstration file showing how to create mock clients and stream readers for testing, including: - `MockOpenAIStreamClient`: Full mock client implementation - `mockStreamReader`: Custom stream reader for controlled test responses - `TestMockOpenAIStreamClient_Demo`: Demonstrates assembling multiple stream chunks - `TestMockOpenAIStreamClient_ErrorHandling`: Shows error handling patterns **Additional context** This refactoring improves the testability of code that depends on go-openai streaming without introducing breaking changes. The existing public API remains unchanged, but now supports dependency injection for testing scenarios. 
The new demo test file serves as documentation for users who want to mock streaming responses in their own tests. (A follow-up lint fix is also included in this commit.)
1 parent ff9d83a commit 857bb0d

File tree

4 files changed

+272
-2
lines changed

4 files changed

+272
-2
lines changed

chat_stream.go

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,21 @@ type ChatCompletionStreamResponse struct {
6565
Usage *Usage `json:"usage,omitempty"`
6666
}
6767

68+
// ChatStreamReader is an interface for reading chat completion streams.
69+
type ChatStreamReader interface {
70+
Recv() (ChatCompletionStreamResponse, error)
71+
Close() error
72+
}
73+
6874
// ChatCompletionStream
6975
// Note: Perhaps it is more elegant to abstract Stream using generics.
7076
type ChatCompletionStream struct {
71-
*streamReader[ChatCompletionStreamResponse]
77+
reader ChatStreamReader
78+
}
79+
80+
// NewChatCompletionStream allows injecting a custom ChatStreamReader (for testing).
81+
func NewChatCompletionStream(reader ChatStreamReader) *ChatCompletionStream {
82+
return &ChatCompletionStream{reader: reader}
7283
}
7384

7485
// CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -106,7 +117,37 @@ func (c *Client) CreateChatCompletionStream(
106117
return
107118
}
108119
stream = &ChatCompletionStream{
109-
streamReader: resp,
120+
reader: resp,
110121
}
111122
return
112123
}
124+
125+
func (s *ChatCompletionStream) Recv() (ChatCompletionStreamResponse, error) {
126+
return s.reader.Recv()
127+
}
128+
129+
func (s *ChatCompletionStream) Close() error {
130+
return s.reader.Close()
131+
}
132+
133+
func (s *ChatCompletionStream) Header() http.Header {
134+
if h, ok := s.reader.(interface{ Header() http.Header }); ok {
135+
return h.Header()
136+
}
137+
return http.Header{}
138+
}
139+
140+
func (s *ChatCompletionStream) GetRateLimitHeaders() map[string]interface{} {
141+
if h, ok := s.reader.(interface{ GetRateLimitHeaders() RateLimitHeaders }); ok {
142+
headers := h.GetRateLimitHeaders()
143+
return map[string]interface{}{
144+
"x-ratelimit-limit-requests": headers.LimitRequests,
145+
"x-ratelimit-limit-tokens": headers.LimitTokens,
146+
"x-ratelimit-remaining-requests": headers.RemainingRequests,
147+
"x-ratelimit-remaining-tokens": headers.RemainingTokens,
148+
"x-ratelimit-reset-requests": headers.ResetRequests.String(),
149+
"x-ratelimit-reset-tokens": headers.ResetTokens.String(),
150+
}
151+
}
152+
return map[string]interface{}{}
153+
}

chat_stream_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,34 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
767767
}
768768
}
769769

770+
type mockStream struct {
771+
calls int
772+
}
773+
774+
// Implement ChatStreamReader.
775+
func (m *mockStream) Recv() (openai.ChatCompletionStreamResponse, error) {
776+
m.calls++
777+
if m.calls == 1 {
778+
return openai.ChatCompletionStreamResponse{ID: "mock1"}, nil
779+
}
780+
return openai.ChatCompletionStreamResponse{}, io.EOF
781+
}
782+
func (m *mockStream) Close() error { return nil }
783+
784+
func TestChatCompletionStream_MockInjection(t *testing.T) {
785+
mock := &mockStream{}
786+
stream := openai.NewChatCompletionStream(mock)
787+
788+
resp, err := stream.Recv()
789+
if err != nil || resp.ID != "mock1" {
790+
t.Errorf("expected mock1, got %v, err %v", resp.ID, err)
791+
}
792+
_, err = stream.Recv()
793+
if !errors.Is(err, io.EOF) {
794+
t.Errorf("expected EOF, got %v", err)
795+
}
796+
}
797+
770798
// Helper funcs.
771799
func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
772800
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

mock_streaming_demo_test.go

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
package openai_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"testing"
8+
9+
"github.com/sashabaranov/go-openai"
10+
)
11+
12+
// This file demonstrates how to create mock clients for go-openai streaming
13+
// functionality. This pattern is useful when testing code that depends on
14+
// go-openai streaming but you want to control the responses for testing.
15+
16+
// MockOpenAIStreamClient demonstrates how to create a full mock client for go-openai.
17+
type MockOpenAIStreamClient struct {
18+
// Configure canned responses
19+
ChatCompletionResponse openai.ChatCompletionResponse
20+
ChatCompletionStreamErr error
21+
22+
// Allow function overrides for more complex scenarios
23+
CreateChatCompletionStreamFn func(
24+
ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionStream, error)
25+
}
26+
27+
func (m *MockOpenAIStreamClient) CreateChatCompletionStream(
28+
ctx context.Context,
29+
req openai.ChatCompletionRequest,
30+
) (*openai.ChatCompletionStream, error) {
31+
if m.CreateChatCompletionStreamFn != nil {
32+
return m.CreateChatCompletionStreamFn(ctx, req)
33+
}
34+
return nil, m.ChatCompletionStreamErr
35+
}
36+
37+
// mockStreamReader creates specific responses for testing.
38+
type mockStreamReader struct {
39+
responses []openai.ChatCompletionStreamResponse
40+
index int
41+
}
42+
43+
func (m *mockStreamReader) Recv() (openai.ChatCompletionStreamResponse, error) {
44+
if m.index >= len(m.responses) {
45+
return openai.ChatCompletionStreamResponse{}, io.EOF
46+
}
47+
resp := m.responses[m.index]
48+
m.index++
49+
return resp, nil
50+
}
51+
52+
func (m *mockStreamReader) Close() error {
53+
return nil
54+
}
55+
56+
func TestMockOpenAIStreamClient_Demo(t *testing.T) {
57+
// Create expected responses that our mock stream will return
58+
expectedResponses := []openai.ChatCompletionStreamResponse{
59+
{
60+
ID: "test-1",
61+
Object: "chat.completion.chunk",
62+
Model: "gpt-3.5-turbo",
63+
Choices: []openai.ChatCompletionStreamChoice{
64+
{
65+
Index: 0,
66+
Delta: openai.ChatCompletionStreamChoiceDelta{
67+
Role: "assistant",
68+
Content: "Hello",
69+
},
70+
},
71+
},
72+
},
73+
{
74+
ID: "test-2",
75+
Object: "chat.completion.chunk",
76+
Model: "gpt-3.5-turbo",
77+
Choices: []openai.ChatCompletionStreamChoice{
78+
{
79+
Index: 0,
80+
Delta: openai.ChatCompletionStreamChoiceDelta{
81+
Content: " World",
82+
},
83+
},
84+
},
85+
},
86+
{
87+
ID: "test-3",
88+
Object: "chat.completion.chunk",
89+
Model: "gpt-3.5-turbo",
90+
Choices: []openai.ChatCompletionStreamChoice{
91+
{
92+
Index: 0,
93+
Delta: openai.ChatCompletionStreamChoiceDelta{},
94+
FinishReason: "stop",
95+
},
96+
},
97+
},
98+
}
99+
100+
// Create mock client with custom stream function
101+
mockClient := &MockOpenAIStreamClient{
102+
CreateChatCompletionStreamFn: func(
103+
_ context.Context, _ openai.ChatCompletionRequest,
104+
) (*openai.ChatCompletionStream, error) {
105+
// Create a mock stream reader with our expected responses
106+
mockStreamReader := &mockStreamReader{
107+
responses: expectedResponses,
108+
index: 0,
109+
}
110+
// Return a new ChatCompletionStream with our mock reader
111+
return openai.NewChatCompletionStream(mockStreamReader), nil
112+
},
113+
}
114+
115+
// Test the mock client
116+
stream, err := mockClient.CreateChatCompletionStream(
117+
context.Background(),
118+
openai.ChatCompletionRequest{
119+
Model: openai.GPT3Dot5Turbo,
120+
Messages: []openai.ChatCompletionMessage{
121+
{
122+
Role: openai.ChatMessageRoleUser,
123+
Content: "Hello!",
124+
},
125+
},
126+
},
127+
)
128+
if err != nil {
129+
t.Fatalf("CreateChatCompletionStream returned error: %v", err)
130+
}
131+
defer stream.Close()
132+
133+
// Verify we get back exactly the responses we configured
134+
fullResponse := ""
135+
for i, expectedResponse := range expectedResponses {
136+
receivedResponse, streamErr := stream.Recv()
137+
if streamErr != nil {
138+
t.Fatalf("stream.Recv() failed at index %d: %v", i, streamErr)
139+
}
140+
141+
// Additional specific checks
142+
if receivedResponse.ID != expectedResponse.ID {
143+
t.Errorf("Response %d ID mismatch. Expected: %s, Got: %s",
144+
i, expectedResponse.ID, receivedResponse.ID)
145+
}
146+
if len(receivedResponse.Choices) > 0 && len(expectedResponse.Choices) > 0 {
147+
expectedContent := expectedResponse.Choices[0].Delta.Content
148+
receivedContent := receivedResponse.Choices[0].Delta.Content
149+
if receivedContent != expectedContent {
150+
t.Errorf("Response %d content mismatch. Expected: %s, Got: %s",
151+
i, expectedContent, receivedContent)
152+
}
153+
fullResponse += receivedContent
154+
}
155+
}
156+
157+
// Verify EOF at the end
158+
_, streamErr := stream.Recv()
159+
if !errors.Is(streamErr, io.EOF) {
160+
t.Errorf("Expected EOF at end of stream, got: %v", streamErr)
161+
}
162+
163+
// Verify the full assembled response
164+
expectedFullResponse := "Hello World"
165+
if fullResponse != expectedFullResponse {
166+
t.Errorf("Full response mismatch. Expected: %s, Got: %s", expectedFullResponse, fullResponse)
167+
}
168+
169+
t.Log("✅ Successfully demonstrated mock OpenAI client with streaming responses!")
170+
t.Logf(" Full response assembled: %q", fullResponse)
171+
}
172+
173+
// TestMockOpenAIStreamClient_ErrorHandling demonstrates error handling.
174+
func TestMockOpenAIStreamClient_ErrorHandling(t *testing.T) {
175+
expectedError := errors.New("mock stream error")
176+
177+
mockClient := &MockOpenAIStreamClient{
178+
ChatCompletionStreamErr: expectedError,
179+
}
180+
181+
_, err := mockClient.CreateChatCompletionStream(
182+
context.Background(),
183+
openai.ChatCompletionRequest{
184+
Model: openai.GPT3Dot5Turbo,
185+
Messages: []openai.ChatCompletionMessage{
186+
{
187+
Role: openai.ChatMessageRoleUser,
188+
Content: "Hello!",
189+
},
190+
},
191+
},
192+
)
193+
194+
if !errors.Is(err, expectedError) {
195+
t.Errorf("Expected error %v, got %v", expectedError, err)
196+
}
197+
198+
t.Log("✅ Successfully demonstrated mock OpenAI client error handling!")
199+
}

stream_reader.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ var (
1616
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
1717
)
1818

19+
// Compile-time check that the concrete streamReader returned by the client
// satisfies ChatStreamReader, so it can back a ChatCompletionStream.
var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil)
20+
1921
type streamable interface {
2022
ChatCompletionStreamResponse | CompletionResponse
2123
}

0 commit comments

Comments
 (0)