Skip to content

Commit a0db37a

Browse files
committed
Fix for #1021:
1. Make Usage field in completions Response to pointer. 2. Add omitempty to json tag Signed-off-by: Hritik003 <[email protected]>
1 parent 86ef1c3 commit a0db37a

File tree

6 files changed

+164
-28
lines changed

6 files changed

+164
-28
lines changed

completion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ type CompletionResponse struct {
242242
Created int64 `json:"created"`
243243
Model string `json:"model"`
244244
Choices []CompletionChoice `json:"choices"`
245-
Usage *Usage `json:"usage"`
245+
Usage *Usage `json:"usage,omitempty"`
246246

247247
httpHeader
248248
}

internal/error_accumulator_test.go

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,32 @@
11
package openai_test
22

33
import (
4-
"bytes"
5-
"errors"
64
"testing"
75

8-
utils "github.com/sashabaranov/go-openai/internal"
9-
"github.com/sashabaranov/go-openai/internal/test"
6+
openai "github.com/sashabaranov/go-openai/internal"
7+
"github.com/sashabaranov/go-openai/internal/test/checks"
108
)
119

12-
func TestErrorAccumulatorBytes(t *testing.T) {
13-
accumulator := &utils.DefaultErrorAccumulator{
14-
Buffer: &bytes.Buffer{},
10+
func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) {
11+
ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator)
12+
if !ok {
13+
t.Fatal("type assertion to *DefaultErrorAccumulator failed")
1514
}
15+
checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}")))
16+
checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}")))
1617

17-
errBytes := accumulator.Bytes()
18-
if len(errBytes) != 0 {
19-
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
20-
}
21-
22-
err := accumulator.Write([]byte("{}"))
23-
if err != nil {
24-
t.Fatalf("%+v", err)
25-
}
26-
27-
errBytes = accumulator.Bytes()
28-
if len(errBytes) == 0 {
29-
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
18+
expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}"
19+
if string(ea.Bytes()) != expected {
20+
t.Fatalf("Expected %q, got %q", expected, ea.Bytes())
3021
}
3122
}
3223

33-
func TestErrorByteWriteErrors(t *testing.T) {
34-
accumulator := &utils.DefaultErrorAccumulator{
35-
Buffer: &test.FailingErrorBuffer{},
24+
func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) {
25+
ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator)
26+
if !ok {
27+
t.Fatal("type assertion to *DefaultErrorAccumulator failed")
3628
}
37-
err := accumulator.Write([]byte("{"))
38-
if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
39-
t.Fatalf("Did not return error when write failed: %v", err)
29+
if len(ea.Bytes()) != 0 {
30+
t.Fatal("Buffer should be empty initially")
4031
}
4132
}

internal/form_builder.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, file
9797
}
9898

9999
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
100+
if fieldname == "" {
101+
return fmt.Errorf("fieldname cannot be empty")
102+
}
100103
return fb.writer.WriteField(fieldname, value)
101104
}
102105

internal/form_builder_test.go

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,57 @@
11
package openai //nolint:testpackage // testing private field
22

33
import (
4+
"errors"
45
"io"
56

67
"github.com/sashabaranov/go-openai/internal/test/checks"
78

89
"bytes"
9-
"errors"
1010
"os"
1111
"testing"
1212
)
1313

14+
type mockFormBuilder struct {
15+
mockCreateFormFile func(string, *os.File) error
16+
mockWriteField func(string, string) error
17+
mockClose func() error
18+
}
19+
20+
func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
21+
return m.mockCreateFormFile(fieldname, file)
22+
}
23+
24+
func (m *mockFormBuilder) WriteField(fieldname, value string) error {
25+
return m.mockWriteField(fieldname, value)
26+
}
27+
28+
func (m *mockFormBuilder) Close() error {
29+
return m.mockClose()
30+
}
31+
32+
func (m *mockFormBuilder) FormDataContentType() string {
33+
return ""
34+
}
35+
36+
func TestCloseMethod(t *testing.T) {
37+
t.Run("NormalClose", func(t *testing.T) {
38+
body := &bytes.Buffer{}
39+
builder := NewFormBuilder(body)
40+
checks.NoError(t, builder.Close(), "正常关闭应成功")
41+
})
42+
43+
t.Run("ErrorPropagation", func(t *testing.T) {
44+
errorMock := errors.New("mock close error")
45+
mockBuilder := &mockFormBuilder{
46+
mockClose: func() error {
47+
return errorMock
48+
},
49+
}
50+
err := mockBuilder.Close()
51+
checks.ErrorIs(t, err, errorMock, "应传递关闭错误")
52+
})
53+
}
54+
1455
type failingWriter struct {
1556
}
1657

@@ -90,3 +131,33 @@ func TestFormBuilderWithReader(t *testing.T) {
90131
err = builder.CreateFormFileReader("file", rnc, "")
91132
checks.NoError(t, err, "formbuilder should not return error")
92133
}
134+
135+
func TestFormDataContentType(t *testing.T) {
136+
t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) {
137+
buf := &bytes.Buffer{}
138+
builder := NewFormBuilder(buf)
139+
140+
contentType := builder.FormDataContentType()
141+
if contentType == "" {
142+
t.Errorf("expected non-empty content type, got empty string")
143+
}
144+
})
145+
}
146+
147+
func TestWriteField(t *testing.T) {
148+
t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) {
149+
buf := &bytes.Buffer{}
150+
builder := NewFormBuilder(buf)
151+
152+
err := builder.WriteField("", "some value")
153+
checks.HasError(t, err, "fieldname is required")
154+
})
155+
156+
t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) {
157+
buf := &bytes.Buffer{}
158+
builder := NewFormBuilder(buf)
159+
160+
err := builder.WriteField("key", "value")
161+
checks.NoError(t, err, "should write field without error")
162+
})
163+
}

internal/marshaller_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package openai_test
2+
3+
import (
4+
"testing"
5+
6+
openai "github.com/sashabaranov/go-openai/internal"
7+
"github.com/sashabaranov/go-openai/internal/test/checks"
8+
)
9+
10+
func TestJSONMarshaller_Normal(t *testing.T) {
11+
jm := &openai.JSONMarshaller{}
12+
data := map[string]string{"key": "value"}
13+
14+
b, err := jm.Marshal(data)
15+
checks.NoError(t, err)
16+
if len(b) == 0 {
17+
t.Fatal("should return non-empty bytes")
18+
}
19+
}
20+
21+
func TestJSONMarshaller_InvalidInput(t *testing.T) {
22+
jm := &openai.JSONMarshaller{}
23+
_, err := jm.Marshal(make(chan int))
24+
checks.HasError(t, err, "should return error for unsupported type")
25+
}
26+
27+
func TestJSONMarshaller_EmptyValue(t *testing.T) {
28+
jm := &openai.JSONMarshaller{}
29+
b, err := jm.Marshal(nil)
30+
checks.NoError(t, err)
31+
if string(b) != "null" {
32+
t.Fatalf("unexpected marshaled value: %s", string(b))
33+
}
34+
}

internal/unmarshaler_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package openai_test
2+
3+
import (
4+
"testing"
5+
6+
openai "github.com/sashabaranov/go-openai/internal"
7+
"github.com/sashabaranov/go-openai/internal/test/checks"
8+
)
9+
10+
func TestJSONUnmarshaler_Normal(t *testing.T) {
11+
jm := &openai.JSONUnmarshaler{}
12+
data := []byte(`{"key":"value"}`)
13+
var v map[string]string
14+
15+
err := jm.Unmarshal(data, &v)
16+
checks.NoError(t, err)
17+
if v["key"] != "value" {
18+
t.Fatal("unmarshal result mismatch")
19+
}
20+
}
21+
22+
func TestJSONUnmarshaler_InvalidJSON(t *testing.T) {
23+
jm := &openai.JSONUnmarshaler{}
24+
data := []byte(`{invalid}`)
25+
var v map[string]interface{}
26+
27+
err := jm.Unmarshal(data, &v)
28+
checks.HasError(t, err, "should return error for invalid JSON")
29+
}
30+
31+
func TestJSONUnmarshaler_EmptyInput(t *testing.T) {
32+
jm := &openai.JSONUnmarshaler{}
33+
var v interface{}
34+
35+
err := jm.Unmarshal(nil, &v)
36+
checks.HasError(t, err, "should return error for nil input")
37+
}

0 commit comments

Comments
 (0)