Skip to content

Commit c7b6c0c

Browse files
authored
feat (provider/togetherai): add image generation support (vercel#4655)
1 parent bb144d2 commit c7b6c0c

File tree

9 files changed

+590
-9
lines changed

9 files changed

+590
-9
lines changed

.changeset/giant-ducks-itch.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@ai-sdk/togetherai': patch
3+
---
4+
5+
feat (provider/togetherai): add image generation support

content/docs/03-ai-sdk-core/35-image-generation.mdx

+10
Original file line numberDiff line numberDiff line change
@@ -234,5 +234,15 @@ try {
234234
| [Fal](/providers/ai-sdk-providers/fal#image-models) | `recraft-v3` | 1:1, 3:4, 4:3, 9:16, 16:9, 9:21, 21:9 |
235235
| [Fal](/providers/ai-sdk-providers/fal#image-models) | `stable-diffusion-3.5-large` | 1:1, 3:4, 4:3, 9:16, 16:9, 9:21, 21:9 |
236236
| [Fal](/providers/ai-sdk-providers/fal#image-models) | `hyper-sdxl` | 1:1, 3:4, 4:3, 9:16, 16:9, 9:21, 21:9 |
237+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `stabilityai/stable-diffusion-xl-base-1.0` | 512x512, 768x768, 1024x1024 |
238+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-dev` | 512x512, 768x768, 1024x1024 |
239+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-dev-lora` | 512x512, 768x768, 1024x1024 |
240+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-schnell` | 512x512, 768x768, 1024x1024 |
241+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-canny` | 512x512, 768x768, 1024x1024 |
242+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-depth` | 512x512, 768x768, 1024x1024 |
243+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-redux` | 512x512, 768x768, 1024x1024 |
244+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1.1-pro` | 512x512, 768x768, 1024x1024 |
245+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-pro` | 512x512, 768x768, 1024x1024 |
246+
| [Together.ai](/providers/ai-sdk-providers/togetherai#image-models) | `black-forest-labs/FLUX.1-schnell-Free` | 512x512, 768x768, 1024x1024 |
237247

238248
Above are a small subset of the image models supported by the AI SDK providers. For more, see the respective provider documentation.

content/providers/01-ai-sdk-providers/24-togetherai.mdx

+59
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,62 @@ The Together.ai provider also supports [completion models](https://docs.together
111111
available models. You can also pass any available provider model ID as a
112112
string if needed.
113113
</Note>
114+
115+
## Image Models
116+
117+
You can create Together.ai image models using the `.image()` factory method.
118+
For more on image generation with the AI SDK see [generateImage()](/docs/reference/ai-sdk-core/generate-image).
119+
120+
```ts
121+
import { togetherai } from '@ai-sdk/togetherai';
122+
import { experimental_generateImage as generateImage } from 'ai';
123+
124+
const { images } = await generateImage({
125+
model: togetherai.image('black-forest-labs/FLUX.1-dev'),
126+
prompt: 'A delighted resplendent quetzal mid flight amidst raindrops',
127+
});
128+
```
129+
130+
You can pass optional provider-specific request parameters using the `providerOptions` argument.
131+
132+
```ts
133+
import { togetherai } from '@ai-sdk/togetherai';
134+
import { experimental_generateImage as generateImage } from 'ai';
135+
136+
const { images } = await generateImage({
137+
model: togetherai.image('black-forest-labs/FLUX.1-dev'),
138+
prompt: 'A delighted resplendent quetzal mid flight amidst raindrops',
139+
size: '512x512',
140+
// Optional additional provider-specific request parameters
141+
providerOptions: {
142+
togetherai: {
143+
steps: 40,
144+
},
145+
},
146+
});
147+
```
148+
149+
For a complete list of available provider-specific options, see the [Together.ai Image Generation API Reference](https://docs.together.ai/reference/post_images-generations).
150+
151+
### Model Capabilities
152+
153+
Together.ai image models support various image dimensions that vary by model. Common sizes include 512x512, 768x768, and 1024x1024, with some models supporting up to 1792x1792. The default size is 1024x1024.
154+
155+
| Available Models |
156+
| ------------------------------------------ |
157+
| `stabilityai/stable-diffusion-xl-base-1.0` |
158+
| `black-forest-labs/FLUX.1-dev` |
159+
| `black-forest-labs/FLUX.1-dev-lora` |
160+
| `black-forest-labs/FLUX.1-schnell` |
161+
| `black-forest-labs/FLUX.1-canny` |
162+
| `black-forest-labs/FLUX.1-depth` |
163+
| `black-forest-labs/FLUX.1-redux` |
164+
| `black-forest-labs/FLUX.1.1-pro` |
165+
| `black-forest-labs/FLUX.1-pro` |
166+
| `black-forest-labs/FLUX.1-schnell-Free` |
167+
168+
<Note>
169+
Please see the [Together.ai models
170+
page](https://docs.together.ai/docs/serverless-models#image-models) for a full
171+
list of available image models and their capabilities.
172+
</Note>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import { togetherai } from '@ai-sdk/togetherai';
2+
import { experimental_generateImage as generateImage } from 'ai';
3+
import 'dotenv/config';
4+
import fs from 'fs';
5+
6+
async function main() {
7+
const result = await generateImage({
8+
model: togetherai.image('black-forest-labs/FLUX.1-dev'),
9+
prompt: 'A delighted resplendent quetzal mid flight amidst raindrops',
10+
size: '1024x1024',
11+
providerOptions: {
12+
togetherai: {
13+
// Together AI specific options
14+
steps: 40,
15+
},
16+
},
17+
});
18+
19+
for (const [index, image] of result.images.entries()) {
20+
const filename = `image-${Date.now()}-${index}.png`;
21+
fs.writeFileSync(filename, image.uint8Array);
22+
console.log(`Image saved to ${filename}`);
23+
}
24+
}
25+
26+
main().catch(console.error);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import { FetchFunction } from '@ai-sdk/provider-utils';
2+
import { createTestServer } from '@ai-sdk/provider-utils/test';
3+
import { describe, expect, it } from 'vitest';
4+
import { TogetherAIImageModel } from './togetherai-image-model';
5+
import { TogetherAIImageSettings } from './togetherai-image-settings';
6+
7+
const prompt = 'A cute baby sea otter';
8+
9+
function createBasicModel({
10+
headers,
11+
fetch,
12+
currentDate,
13+
settings,
14+
}: {
15+
headers?: () => Record<string, string>;
16+
fetch?: FetchFunction;
17+
currentDate?: () => Date;
18+
settings?: TogetherAIImageSettings;
19+
} = {}) {
20+
return new TogetherAIImageModel(
21+
'stabilityai/stable-diffusion-xl',
22+
settings ?? {},
23+
{
24+
provider: 'togetherai',
25+
baseURL: 'https://api.example.com',
26+
headers: headers ?? (() => ({ 'api-key': 'test-key' })),
27+
fetch,
28+
_internal: {
29+
currentDate,
30+
},
31+
},
32+
);
33+
}
34+
35+
const server = createTestServer({
36+
'https://api.example.com/*': {
37+
response: {
38+
type: 'json-value',
39+
body: {
40+
id: 'test-id',
41+
data: [{ index: 0, b64_json: 'test-base64-content' }],
42+
model: 'stabilityai/stable-diffusion-xl',
43+
object: 'list',
44+
},
45+
},
46+
},
47+
});
48+
49+
describe('doGenerate', () => {
50+
it('should pass the correct parameters including size and seed', async () => {
51+
const model = createBasicModel();
52+
53+
await model.doGenerate({
54+
prompt,
55+
n: 1,
56+
size: '1024x1024',
57+
seed: 42,
58+
providerOptions: { togetherai: { additional_param: 'value' } },
59+
aspectRatio: undefined,
60+
});
61+
62+
expect(await server.calls[0].requestBody).toStrictEqual({
63+
model: 'stabilityai/stable-diffusion-xl',
64+
prompt,
65+
seed: 42,
66+
n: 1,
67+
width: 1024,
68+
height: 1024,
69+
response_format: 'base64',
70+
additional_param: 'value',
71+
});
72+
});
73+
74+
it('should call the correct url', async () => {
75+
const model = createBasicModel();
76+
77+
await model.doGenerate({
78+
prompt,
79+
n: 1,
80+
size: '1024x1024',
81+
seed: 42,
82+
providerOptions: {},
83+
aspectRatio: undefined,
84+
});
85+
86+
expect(server.calls[0].requestMethod).toStrictEqual('POST');
87+
expect(server.calls[0].requestUrl).toStrictEqual(
88+
'https://api.example.com/images/generations',
89+
);
90+
});
91+
92+
it('should pass headers', async () => {
93+
const modelWithHeaders = createBasicModel({
94+
headers: () => ({
95+
'Custom-Provider-Header': 'provider-header-value',
96+
}),
97+
});
98+
99+
await modelWithHeaders.doGenerate({
100+
prompt,
101+
n: 1,
102+
size: undefined,
103+
seed: undefined,
104+
providerOptions: {},
105+
aspectRatio: undefined,
106+
headers: {
107+
'Custom-Request-Header': 'request-header-value',
108+
},
109+
});
110+
111+
expect(server.calls[0].requestHeaders).toStrictEqual({
112+
'content-type': 'application/json',
113+
'custom-provider-header': 'provider-header-value',
114+
'custom-request-header': 'request-header-value',
115+
});
116+
});
117+
118+
it('should handle API errors', async () => {
119+
server.urls['https://api.example.com/*'].response = {
120+
type: 'error',
121+
status: 400,
122+
body: JSON.stringify({
123+
error: {
124+
message: 'Bad Request',
125+
},
126+
}),
127+
};
128+
129+
const model = createBasicModel();
130+
await expect(
131+
model.doGenerate({
132+
prompt,
133+
n: 1,
134+
size: undefined,
135+
seed: undefined,
136+
providerOptions: {},
137+
aspectRatio: undefined,
138+
}),
139+
).rejects.toMatchObject({
140+
message: 'Bad Request',
141+
});
142+
});
143+
144+
describe('warnings', () => {
145+
it('should return aspectRatio warning when aspectRatio is provided', async () => {
146+
const model = createBasicModel();
147+
148+
const result = await model.doGenerate({
149+
prompt,
150+
n: 1,
151+
size: '1024x1024',
152+
aspectRatio: '1:1',
153+
seed: 123,
154+
providerOptions: {},
155+
});
156+
157+
expect(result.warnings).toContainEqual({
158+
type: 'unsupported-setting',
159+
setting: 'aspectRatio',
160+
details:
161+
'This model does not support the `aspectRatio` option. Use `size` instead.',
162+
});
163+
});
164+
});
165+
166+
it('should respect the abort signal', async () => {
167+
const model = createBasicModel();
168+
const controller = new AbortController();
169+
170+
const generatePromise = model.doGenerate({
171+
prompt,
172+
n: 1,
173+
size: undefined,
174+
seed: undefined,
175+
providerOptions: {},
176+
aspectRatio: undefined,
177+
abortSignal: controller.signal,
178+
});
179+
180+
controller.abort();
181+
182+
await expect(generatePromise).rejects.toThrow('This operation was aborted');
183+
});
184+
185+
describe('response metadata', () => {
186+
it('should include timestamp, headers and modelId in response', async () => {
187+
const testDate = new Date('2024-01-01T00:00:00Z');
188+
const model = createBasicModel({
189+
currentDate: () => testDate,
190+
});
191+
192+
const result = await model.doGenerate({
193+
prompt,
194+
n: 1,
195+
size: undefined,
196+
seed: undefined,
197+
providerOptions: {},
198+
aspectRatio: undefined,
199+
});
200+
201+
expect(result.response).toStrictEqual({
202+
timestamp: testDate,
203+
modelId: 'stabilityai/stable-diffusion-xl',
204+
headers: expect.any(Object),
205+
});
206+
});
207+
208+
it('should include response headers from API call', async () => {
209+
server.urls['https://api.example.com/*'].response = {
210+
type: 'json-value',
211+
body: {
212+
id: 'test-id',
213+
data: [{ index: 0, b64_json: 'test-base64-content' }],
214+
model: 'stabilityai/stable-diffusion-xl',
215+
object: 'list',
216+
},
217+
headers: {
218+
'x-request-id': 'test-request-id',
219+
'content-length': '128',
220+
},
221+
};
222+
223+
const model = createBasicModel();
224+
const result = await model.doGenerate({
225+
prompt,
226+
n: 1,
227+
size: undefined,
228+
seed: undefined,
229+
providerOptions: {},
230+
aspectRatio: undefined,
231+
});
232+
233+
expect(result.response.headers).toStrictEqual({
234+
'x-request-id': 'test-request-id',
235+
'content-type': 'application/json',
236+
'content-length': '128',
237+
});
238+
});
239+
});
240+
});
241+
242+
describe('constructor', () => {
243+
it('should expose correct provider and model information', () => {
244+
const model = createBasicModel();
245+
246+
expect(model.provider).toBe('togetherai');
247+
expect(model.modelId).toBe('stabilityai/stable-diffusion-xl');
248+
expect(model.specificationVersion).toBe('v1');
249+
expect(model.maxImagesPerCall).toBe(1);
250+
});
251+
252+
it('should use maxImagesPerCall from settings', () => {
253+
const model = createBasicModel({
254+
settings: {
255+
maxImagesPerCall: 4,
256+
},
257+
});
258+
259+
expect(model.maxImagesPerCall).toBe(4);
260+
});
261+
262+
it('should default maxImagesPerCall to 1 when not specified', () => {
263+
const model = createBasicModel();
264+
265+
expect(model.maxImagesPerCall).toBe(1);
266+
});
267+
});

0 commit comments

Comments
 (0)