Skip to content

Commit 86ef1c3

Browse files
authored
Merge branch 'sashabaranov:master' into master
2 parents 846b744 + d7dca83 commit 86ef1c3

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

image.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,32 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
132132
return
133133
}
134134

135+
// WrapReader wraps an io.Reader with filename and Content-type.
136+
func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader {
137+
return file{rdr, filename, contentType}
138+
}
139+
140+
type file struct {
141+
io.Reader
142+
name string
143+
contentType string
144+
}
145+
146+
func (f file) Name() string {
147+
if f.name != "" {
148+
return f.name
149+
} else if named, ok := f.Reader.(interface{ Name() string }); ok {
150+
return named.Name()
151+
}
152+
return ""
153+
}
154+
155+
func (f file) ContentType() string {
156+
return f.contentType
157+
}
158+
135159
// ImageEditRequest represents the request structure for the image API.
160+
// Use WrapReader to wrap an io.Reader with filename and Content-type.
136161
type ImageEditRequest struct {
137162
Image io.Reader `json:"image,omitempty"`
138163
Mask io.Reader `json:"mask,omitempty"`
@@ -150,15 +175,15 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
150175
body := &bytes.Buffer{}
151176
builder := c.createFormBuilder(body)
152177

153-
// image, filename is not required
178+
// image, filename verification can be postponed
154179
err = builder.CreateFormFileReader("image", request.Image, "")
155180
if err != nil {
156181
return
157182
}
158183

159184
// mask, it is optional
160185
if request.Mask != nil {
161-
// mask, filename is not required
186+
// filename verification can be postponed
162187
err = builder.CreateFormFileReader("mask", request.Mask, "")
163188
if err != nil {
164189
return
@@ -206,6 +231,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
206231
}
207232

208233
// ImageVariRequest represents the request structure for the image API.
234+
// Use WrapReader to wrap an io.Reader with filename and Content-type.
209235
type ImageVariRequest struct {
210236
Image io.Reader `json:"image,omitempty"`
211237
Model string `json:"model,omitempty"`
@@ -221,7 +247,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
221247
body := &bytes.Buffer{}
222248
builder := c.createFormBuilder(body)
223249

224-
// image, filename is not required
250+
// image, filename verification can be postponed
225251
err = builder.CreateFormFileReader("image", request.Image, "")
226252
if err != nil {
227253
return

internal/form_builder.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,18 @@ func escapeQuotes(s string) string {
3939
}
4040

4141
// CreateFormFileReader creates a form field with a file reader.
42-
// The filename in parameters can be an empty string.
43-
// The filename in Content-Disposition is required, But it can be an empty string.
42+
// The filename in Content-Disposition is required.
4443
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
44+
if filename == "" {
45+
if f, ok := r.(interface{ Name() string }); ok {
46+
filename = f.Name()
47+
}
48+
}
49+
var contentType string
50+
if f, ok := r.(interface{ ContentType() string }); ok {
51+
contentType = f.ContentType()
52+
}
53+
4554
h := make(textproto.MIMEHeader)
4655
h.Set(
4756
"Content-Disposition",
@@ -51,6 +60,10 @@ func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader
5160
escapeQuotes(filepath.Base(filename)),
5261
),
5362
)
63+
// content type is optional, but it can be set
64+
if contentType != "" {
65+
h.Set("Content-Type", contentType)
66+
}
5467

5568
fieldWriter, err := fb.writer.CreatePart(h)
5669
if err != nil {

internal/form_builder_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package openai //nolint:testpackage // testing private field
22

33
import (
4+
"io"
5+
46
"github.com/sashabaranov/go-openai/internal/test/checks"
57

68
"bytes"
@@ -53,6 +55,18 @@ func (*failingReader) Read([]byte) (int, error) {
5355
return 0, errMockFailingReaderError
5456
}
5557

58+
type readerWithNameAndContentType struct {
59+
io.Reader
60+
}
61+
62+
func (*readerWithNameAndContentType) Name() string {
63+
return ""
64+
}
65+
66+
func (*readerWithNameAndContentType) ContentType() string {
67+
return "image/png"
68+
}
69+
5670
func TestFormBuilderWithReader(t *testing.T) {
5771
file, err := os.CreateTemp(t.TempDir(), "")
5872
if err != nil {
@@ -71,4 +85,8 @@ func TestFormBuilderWithReader(t *testing.T) {
7185
successReader := &bytes.Buffer{}
7286
err = builder.CreateFormFileReader("file", successReader, "")
7387
checks.NoError(t, err, "formbuilder should not return error")
88+
89+
rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}}
90+
err = builder.CreateFormFileReader("file", rnc, "")
91+
checks.NoError(t, err, "formbuilder should not return error")
7492
}

0 commit comments

Comments
 (0)