Skip to content

Commit a4e8f37

Browse files
diadorerNikita Barinovhanouticelina
authored
✨ Add feature extraction task for Nebius (#1436)
Added support for [embedding models](https://studio.nebius.com/?modality=embedding) for Nebius provider. --------- Co-authored-by: Nikita Barinov <[email protected]> Co-authored-by: Celina Hanouti <[email protected]>
1 parent 8581ca8 commit a4e8f37

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

packages/inference/src/lib/getProviderHelper.ts

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
115115
"text-to-image": new Nebius.NebiusTextToImageTask(),
116116
conversational: new Nebius.NebiusConversationalTask(),
117117
"text-generation": new Nebius.NebiusTextGenerationTask(),
118+
"feature-extraction": new Nebius.NebiusFeatureExtractionTask(),
118119
},
119120
novita: {
120121
conversational: new Novita.NovitaConversationalTask(),

packages/inference/src/providers/nebius.ts

+30-2
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
*
1515
* Thanks!
1616
*/
17+
import type { FeatureExtractionOutput } from "@huggingface/tasks";
1718
import { InferenceOutputError } from "../lib/InferenceOutputError";
1819
import type { BodyParams, UrlParams } from "../types";
1920
import { omit } from "../utils/omit";
2021
import {
2122
BaseConversationalTask,
2223
BaseTextGenerationTask,
2324
TaskProviderHelper,
25+
type FeatureExtractionTaskHelper,
2426
type TextToImageTaskHelper,
2527
} from "./providerHelper";
2628

@@ -32,6 +34,12 @@ interface NebiusBase64ImageGeneration {
3234
}>;
3335
}
3436

37+
interface NebiusEmbeddingsResponse {
38+
data: Array<{
39+
embedding: number[];
40+
}>;
41+
}
42+
3543
export class NebiusConversationalTask extends BaseConversationalTask {
3644
constructor() {
3745
super("nebius", NEBIUS_API_BASE_URL);
@@ -59,8 +67,7 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
5967
};
6068
}
6169

62-
makeRoute(params: UrlParams): string {
63-
void params;
70+
makeRoute(): string {
6471
return "v1/images/generations";
6572
}
6673

@@ -88,3 +95,24 @@ export class NebiusTextToImageTask extends TaskProviderHelper implements TextToI
8895
throw new InferenceOutputError("Expected Nebius text-to-image response format");
8996
}
9097
}
98+
99+
export class NebiusFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {
100+
constructor() {
101+
super("nebius", NEBIUS_API_BASE_URL);
102+
}
103+
104+
preparePayload(params: BodyParams): Record<string, unknown> {
105+
return {
106+
input: params.args.inputs,
107+
model: params.model,
108+
};
109+
}
110+
111+
makeRoute(): string {
112+
return "v1/embeddings";
113+
}
114+
115+
async getResponse(response: NebiusEmbeddingsResponse): Promise<FeatureExtractionOutput> {
116+
return response.data.map((item) => item.embedding);
117+
}
118+
}

packages/inference/test/InferenceClient.spec.ts

+16
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,12 @@ describe.skip("InferenceClient", () => {
13691369
status: "live",
13701370
task: "text-to-image",
13711371
},
1372+
"BAAI/bge-multilingual-gemma2": {
1373+
providerId: "BAAI/bge-multilingual-gemma2",
1374+
hfModelId: "BAAI/bge-multilingual-gemma2",
1375+
status: "live",
1376+
task: "feature-extraction",
1377+
},
13721378
};
13731379

13741380
it("chatCompletion", async () => {
@@ -1406,6 +1412,16 @@ describe.skip("InferenceClient", () => {
14061412
});
14071413
expect(res).toBeInstanceOf(Blob);
14081414
});
1415+
1416+
it("featureExtraction", async () => {
1417+
const res = await client.featureExtraction({
1418+
model: "BAAI/bge-multilingual-gemma2",
1419+
inputs: "That is a happy person",
1420+
});
1421+
1422+
expect(res).toBeInstanceOf(Array);
1423+
expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)]));
1424+
});
14091425
},
14101426
TIMEOUT
14111427
);

0 commit comments

Comments
 (0)