Add Groq as an inference provider #1352

Open · wants to merge 8 commits into base: main
2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -58,6 +58,7 @@ Currently, we support the following providers:
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [Groq](https://groq.com)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
@@ -84,6 +85,7 @@ Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline task:
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
5 changes: 5 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -3,6 +3,7 @@ import * as Cerebras from "../providers/cerebras";
import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as Fireworks from "../providers/fireworks-ai";
import * as Groq from "../providers/groq";
import * as HFInference from "../providers/hf-inference";

import * as Hyperbolic from "../providers/hyperbolic";
@@ -95,6 +96,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
"fireworks-ai": {
conversational: new Fireworks.FireworksConversationalTask(),
},
groq: {
conversational: new Groq.GroqConversationalTask(),
"text-generation": new Groq.GroqTextGenerationTask(),
},
hyperbolic: {
"text-to-image": new Hyperbolic.HyperbolicTextToImageTask(),
conversational: new Hyperbolic.HyperbolicConversationalTask(),
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -24,6 +24,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
cohere: {},
"fal-ai": {},
"fireworks-ai": {},
groq: {},
"hf-inference": {},
hyperbolic: {},
nebius: {},
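For local development against a model that is not yet registered on huggingface.co, an entry can be added under the new `groq` bucket. A self-contained sketch of the entry shape (the local `Mapping` interface is illustrative; the real type lives in the package):

```typescript
// Illustrative stand-in for the package's mapping value type.
interface Mapping {
  hfModelId: string;
  providerId: string;
  status: "live" | "staging";
  task: string;
}

// Mirrors HARDCODED_MODEL_INFERENCE_MAPPING with the empty `groq` bucket
// this hunk adds, then registers a dev-only entry the way the PR's tests do.
const mapping: Record<string, Record<string, Mapping>> = { groq: {} };

mapping["groq"]["meta-llama/Llama-3.3-70B-Instruct"] = {
  hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
  providerId: "llama-3.3-70b-versatile", // Groq-side model name
  status: "live",
  task: "conversational",
};

console.log(mapping["groq"]["meta-llama/Llama-3.3-70B-Instruct"].providerId);
```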
40 changes: 40 additions & 0 deletions packages/inference/src/providers/groq.ts
@@ -0,0 +1,40 @@
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";

/**
* See the registered mapping of HF model ID => Groq model ID here:
*
* https://huggingface.co/api/partners/groq/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the "HARDCODED_MODEL_INFERENCE_MAPPING" dictionary in consts.ts, for dev purposes.
*
* - If you work at Groq and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Groq, please open an issue in this repo,
* and we will tag Groq team members.
*
* Thanks!
*/

const GROQ_API_BASE_URL = "https://api.groq.com";

export class GroqTextGenerationTask extends BaseTextGenerationTask {
constructor() {
super("groq", GROQ_API_BASE_URL);
}

override makeRoute(): string {
return "/openai/v1/chat/completions";
}
}

export class GroqConversationalTask extends BaseConversationalTask {
constructor() {
super("groq", GROQ_API_BASE_URL);
}

override makeRoute(): string {
return "/openai/v1/chat/completions";
}
}
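Both task helpers resolve to the same OpenAI-compatible endpoint. A dependency-free sketch of the URL and payload the helper ends up sending (values mirror the recorded test tapes in this PR):

```typescript
// Reconstructs the request target from the constants above:
// base URL + OpenAI-compatible chat completions route.
const GROQ_API_BASE_URL = "https://api.groq.com";
const route = "/openai/v1/chat/completions";
const url = `${GROQ_API_BASE_URL}${route}`;

// Body shape sent for a chat completion, using the Groq-side model ID.
const body = JSON.stringify({
  messages: [{ role: "user", content: "Say 'this is a test'" }],
  model: "llama-3.3-70b-versatile",
});

console.log(url); // https://api.groq.com/openai/v1/chat/completions
```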
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -43,6 +43,7 @@ export const INFERENCE_PROVIDERS = [
"cohere",
"fal-ai",
"fireworks-ai",
"groq",
"hf-inference",
"hyperbolic",
"nebius",
51 changes: 51 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -1691,4 +1691,55 @@ describe.concurrent("InferenceClient", () => {
},
TIMEOUT
);
describe.concurrent(
"Groq",
() => {
const client = new InferenceClient(env.HF_GROQ_KEY ?? "dummy");

HARDCODED_MODEL_INFERENCE_MAPPING["groq"] = {
"meta-llama/Llama-3.3-70B-Instruct": {
hfModelId: "meta-llama/Llama-3.3-70B-Instruct",
providerId: "llama-3.3-70b-versatile",
status: "live",
task: "conversational",
},
};

it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "meta-llama/Llama-3.3-70B-Instruct",
provider: "groq",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toContain("two");
}
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "meta-llama/Llama-3.3-70B-Instruct",
provider: "groq",
messages: [{ role: "user", content: "Say 'this is a test'" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
});
},
TIMEOUT
);
});
53 changes: 52 additions & 1 deletion packages/inference/test/tapes.json
@@ -7486,5 +7486,56 @@
"vary": "Origin"
}
}
}
},
"5688b06e0eb91dd68eef47fad94783b8b38a56cceae637c57521a48d4711ff2d": {
"url": "https://api.groq.com/openai/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Say 'this is a test'\"}],\"stream\":true,\"model\":\"llama-3.3-70b-versatile\"}"
},
"response": {
"body": "data: {\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null,\"finish_reason\":null}],\"x_groq\":{\"id\":\"req_01js27cen2f5brg46sjpmn0z3m\"}}\n\ndata: {\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"This\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" is\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" a\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" test\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: 
{\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\".\"},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-f3f09b15-5394-4b39-bd41-5b30c9ba741e\",\"object\":\"chat.completion.chunk\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"choices\":[{\"index\":0,\"delta\":{},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"x_groq\":{\"id\":\"req_01js27cen2f5brg46sjpmn0z3m\",\"usage\":{\"queue_time\":0.220723924,\"prompt_tokens\":42,\"prompt_time\":0.010160524,\"completion_tokens\":6,\"completion_time\":0.021818182,\"total_tokens\":48,\"total_time\":0.031978706}}}\n\ndata: [DONE]\n\n",
"status": 200,
"statusText": "OK",
"headers": {
"alt-svc": "h3=\":443\"; ma=86400",
"cache-control": "no-cache",
"cf-cache-status": "DYNAMIC",
"connection": "keep-alive",
"content-type": "text/event-stream",
"server": "cloudflare",
"transfer-encoding": "chunked",
"vary": "Origin, Accept-Encoding"
}
}
},
"01cb4504b502c793085788df0984db81d4f72532cebe5862d9558b0cbf07519c": {
"url": "https://api.groq.com/openai/v1/chat/completions",
"init": {
"headers": {
"Content-Type": "application/json"
},
"method": "POST",
"body": "{\"messages\":[{\"role\":\"user\",\"content\":\"Complete this sentence with words, one plus one is equal \"}],\"model\":\"llama-3.3-70b-versatile\"}"
},
"response": {
"body": "{\"id\":\"chatcmpl-76a3744d-bc7c-4153-900b-99d189af3720\",\"object\":\"chat.completion\",\"created\":1744905321,\"model\":\"llama-3.3-70b-versatile\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"two.\"},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":{\"queue_time\":0.218711445,\"prompt_tokens\":47,\"prompt_time\":0.010602797,\"completion_tokens\":3,\"completion_time\":0.010909091,\"total_tokens\":50,\"total_time\":0.021511888},\"usage_breakdown\":{\"models\":null},\"system_fingerprint\":\"fp_3f3b593e33\",\"instance_id\":\"LLAMA-33-70B-DMM1-PROD2-1\",\"x_groq\":{\"id\":\"req_01js27cen0fafs6j9xp8w18nfm\"}}",
"status": 200,
"statusText": "OK",
"headers": {
"alt-svc": "h3=\":443\"; ma=86400",
"cache-control": "private, max-age=0, no-store, no-cache, must-revalidate",
"cf-cache-status": "DYNAMIC",
"connection": "keep-alive",
"content-encoding": "br",
"content-type": "application/json",
"server": "cloudflare",
"transfer-encoding": "chunked",
"vary": "Origin, Accept-Encoding"
}
}
}
}