embeddings: add support for prefixes in embeddings #4524

Draft · wants to merge 3 commits into base: main
3 changes: 3 additions & 0 deletions core/config/yaml/models.ts
@@ -80,6 +80,9 @@ async function modelConfigToBaseLLM(
if (model.embedOptions?.maxChunkSize) {
options.maxEmbeddingChunkSize = model.embedOptions.maxChunkSize;
}
if (model.embedOptions?.embeddingPrefixes) {
options.embeddingPrefixes = model.embedOptions.embeddingPrefixes;
}

// These are params that are at model config level in JSON
// But we decided to move to nested `env` in YAML
7 changes: 6 additions & 1 deletion core/index.d.ts
@@ -139,7 +139,7 @@ export interface ILLM extends LLMOptions {
options?: LLMFullCompletionOptions,
): Promise<ChatMessage>;

embed(chunks: string[]): Promise<number[][]>;
embed(chunks: string[], embeddingTask: EmbeddingTasks): Promise<number[][]>;

rerank(query: string, chunks: Chunk[]): Promise<number[]>;

@@ -479,6 +479,9 @@ export interface LLMFullCompletionOptions extends BaseCompletionOptions {

export type ToastType = "info" | "error" | "warning";

export type EmbeddingTasks = "chunk" | "query";
export type EmbeddingPrefixes = Partial<Record<EmbeddingTasks, string>>;

export interface LLMOptions {
model: string;

@@ -514,6 +517,7 @@ export interface LLMOptions {
embeddingId?: string;
maxEmbeddingChunkSize?: number;
maxEmbeddingBatchSize?: number;
embeddingPrefixes?: EmbeddingPrefixes;

// Cloudflare options
accountId?: string;
@@ -978,6 +982,7 @@ export interface EmbedOptions {
requestOptions?: RequestOptions;
maxEmbeddingChunkSize?: number;
maxEmbeddingBatchSize?: number;
embeddingPrefixes?: EmbeddingPrefixes;

// AWS options
profile?: string;
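To make the interface change above concrete, here is a hypothetical caller of the updated `embed` signature (illustrative only, not part of this PR; the import path is an assumption):

```typescript
// Illustrative usage of the new task-aware embed signature; not part of the diff.
import type { ILLM } from "core"; // import path is an assumption

async function exampleUsage(llm: ILLM) {
  // Indexing path: document/code chunks are embedded with the "chunk" task,
  // so the configured chunk prefix (e.g. "search_document: ") is prepended.
  const chunkVectors = await llm.embed(["const x = 1;", "function f() {}"], "chunk");

  // Retrieval path: the user query is embedded with the "query" task,
  // so the query prefix (e.g. "search_query: ") is prepended instead.
  const [queryVector] = await llm.embed(["where is x defined?"], "query");

  return { chunkVectors, queryVector };
}
```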
5 changes: 3 additions & 2 deletions core/indexing/LanceDbIndex.ts
@@ -199,7 +199,7 @@ export class LanceDbIndex implements CodebaseIndex {
return [];
}
try {
return await this.embeddingsProvider.embed(chunks.map((c) => c.content));
return await this.embeddingsProvider.embed(chunks.map((c) => c.content), "chunk");
} catch (err) {
throw new Error(
`Failed to generate embeddings for ${chunks.length} chunks with provider: ${this.embeddingsProvider.embeddingId}: ${err}`,
@@ -422,10 +422,11 @@ export class LanceDbIndex implements CodebaseIndex {
try {
[vector] = await this.embeddingsProvider.embed(
chunks.map((c) => c.content),
"query",
);
} catch (err) {
// If we fail to chunk, we just use what was happening before.
[vector] = await this.embeddingsProvider.embed([query]);
[vector] = await this.embeddingsProvider.embed([query], "query");
}

const db = await lance.connect(getLanceDbPath());
6 changes: 3 additions & 3 deletions core/indexing/docs/DocsService.ts
@@ -460,7 +460,7 @@ export default class DocsService {
// This particular failure will not mark as a failed config in global context
// Since SiteIndexingConfig is likely to be valid
try {
await provider.embed(["continue-test-run"]);
await provider.embed(["continue-test-run"], "chunk");
} catch (e) {
console.error("Failed to test embeddings connection", e);
return;
@@ -596,7 +596,7 @@ export default class DocsService {
try {
const subpathEmbeddings =
article.chunks.length > 0
? await provider.embed(article.chunks.map((c) => c.content))
? await provider.embed(article.chunks.map((c) => c.content), "chunk")
: [];
chunks.push(...article.chunks);
embeddings.push(...subpathEmbeddings);
@@ -761,7 +761,7 @@ export default class DocsService {
});
}

const [vector] = await provider.embed([query]);
const [vector] = await provider.embed([query], "query");

return await this.retrieveChunks(startUrl, vector, nRetrieve, isPreindexed);
}
21 changes: 16 additions & 5 deletions core/llm/index.ts
@@ -15,6 +15,8 @@ import {
ChatMessage,
Chunk,
CompletionOptions,
EmbeddingPrefixes,
EmbeddingTasks,
ILLM,
LLMFullCompletionOptions,
LLMOptions,
@@ -160,6 +162,7 @@ export abstract class BaseLLM implements ILLM {
embeddingId: string;
maxEmbeddingChunkSize: number;
maxEmbeddingBatchSize: number;
embeddingPrefixes?: EmbeddingPrefixes;

private _llmOptions: LLMOptions;

@@ -252,9 +255,13 @@
options.maxEmbeddingBatchSize ?? DEFAULT_MAX_BATCH_SIZE;
this.maxEmbeddingChunkSize =
options.maxEmbeddingChunkSize ?? DEFAULT_MAX_CHUNK_SIZE;
this.embeddingId = `${this.constructor.name}::${this.model}::${this.maxEmbeddingChunkSize}`;
if (options?.embeddingPrefixes) {
this.embeddingPrefixes = options.embeddingPrefixes;
}
// The id includes only the chunking prefix, because that is what is used to embed the chunks.
const prefix = this.embeddingPrefixes?.chunk ? "::" + this.embeddingPrefixes.chunk : undefined;
this.embeddingId = `${this.constructor.name}::${this.model}::${this.maxEmbeddingChunkSize}${prefix ?? ""}`;
}

protected createOpenAiAdapter() {
return constructLlmApi({
provider: this.providerName as any,
@@ -900,16 +907,20 @@
return batchedChunks;
}

async embed(chunks: string[]): Promise<number[][]> {
async embed(chunks: string[], embeddingTask: EmbeddingTasks): Promise<number[][]> {
const batches = this.getBatchedChunks(chunks);

return (
await Promise.all(
batches.map(async (batch) => {
if (batch.length === 0) {
return [];
}

const prefix = this.embeddingPrefixes?.[embeddingTask];
if (prefix) {
batch = batch.map((chunk) => prefix + chunk);
}
const embeddings = await withExponentialBackoff<number[][]>(
async () => {
if (this.shouldUseOpenAIAdapter("embed") && this.openaiAdapter) {
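For readers skimming the diff, the following is a standalone sketch (not part of the PR) of the prefixing behavior that `BaseLLM.embed` gains above: when a prefix is configured for the requested task, it is prepended to every string in the batch before the provider is called. The function and type names below are illustrative. Note also that the chunk prefix is folded into `embeddingId`, which appears intended to invalidate previously indexed embeddings when the chunk prefix changes.

```typescript
// Illustrative sketch of the prefix logic added to BaseLLM.embed; names are hypothetical.
type EmbeddingTasks = "chunk" | "query";
type EmbeddingPrefixes = Partial<Record<EmbeddingTasks, string>>;

function applyEmbeddingPrefix(
  batch: string[],
  task: EmbeddingTasks,
  prefixes?: EmbeddingPrefixes,
): string[] {
  const prefix = prefixes?.[task];
  // With { chunk: "search_document: " }, "foo()" is sent as "search_document: foo()".
  return prefix ? batch.map((text) => prefix + text) : batch;
}

// Example with nomic-style prefixes:
applyEmbeddingPrefix(
  ["function add(a, b) { return a + b; }"],
  "chunk",
  { chunk: "search_document: ", query: "search_query: " },
);
```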
2 changes: 1 addition & 1 deletion core/package-lock.json


2 changes: 1 addition & 1 deletion core/package.json
@@ -46,7 +46,7 @@
"@aws-sdk/client-sagemaker-runtime": "^3.758.0",
"@aws-sdk/credential-providers": "^3.758.0",
"@continuedev/config-types": "^1.0.13",
"@continuedev/config-yaml": "^1.0.63",
"@continuedev/config-yaml": "^1.0.64",
"@continuedev/fetch": "^1.0.4",
"@continuedev/llm-info": "^1.0.2",
"@continuedev/openai-adapters": "^1.0.18",
12 changes: 10 additions & 2 deletions docs/docs/customize/deep-dives/docs.mdx
@@ -316,7 +316,11 @@ Here is the equivalent minimal example for Jetbrains, which requires setting up
],
"embeddingsProvider": {
"provider": "lmstudio",
"model": "nomic-ai/nomic-embed-text-v1.5-GGUF"
"model": "nomic-ai/nomic-embed-text-v1.5-GGUF",
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
}
}
}
```
@@ -400,7 +404,11 @@ The following configuration example includes:
},
"embeddingsProvider": {
"provider": "lmstudio",
"model": "nomic-ai/nomic-embed-text-v1.5-GGUF"
"model": "nomic-ai/nomic-embed-text-v1.5-GGUF",
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
},
}
}
```
10 changes: 9 additions & 1 deletion docs/docs/customize/model-providers/more/lmstudio.mdx
@@ -49,7 +49,11 @@ LMStudio supports embeddings endpoints, and comes with the `nomic-ai/nomic-embed
{
"embeddingsProvider": {
"provider": "lmstudio",
"model": "nomic-ai/nomic-embed-text-v1.5-GGUF"
"model": "nomic-ai/nomic-embed-tefxt-v1.5-GGUF",
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
}
}
}
```
@@ -77,6 +81,10 @@ To configure a remote instance of LM Studio, add the `"apiBase"` property to you
"embeddingsProvider": {
"provider": "lmstudio",
"model": "nomic-ai/nomic-embed-text-v1.5-GGUF"
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
},
}
}
```
6 changes: 5 additions & 1 deletion docs/docs/customize/model-providers/more/vllm.mdx
@@ -91,7 +91,11 @@ We recommend configuring **Nomic Embed Text** as your embeddings model.
"embeddingsProvider": {
"provider": "vllm",
"model": "nomic-ai/nomic-embed-text-v1",
"apiBase": "http://<vllm embed endpoint>/v1"
"apiBase": "http://<vllm embed endpoint>/v1",
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
}
}
}
```
6 changes: 5 additions & 1 deletion docs/docs/customize/model-providers/top-level/ollama.mdx
@@ -84,7 +84,11 @@ We recommend configuring **Nomic Embed Text** as your embeddings model.
{
"embeddingsProvider": {
"provider": "ollama",
Contributor comment:

Let's remove the "embeddingPrefixes" from the JSON in the docs here. I know it's correct, but it ends up looking very complicated for a beginner setup. The way we will solve this for all users going forward is by having the blocks on hub.continue.dev include the necessary embedding prefixes, and then suggesting in the docs that users use `uses: ollama/nomic-embed-text`, for example.
"model": "nomic-embed-text"
"model": "nomic-embed-text",
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
}
}
}
```
6 changes: 5 additions & 1 deletion docs/docs/json-reference.md
@@ -105,6 +105,7 @@ Example
"disableInFiles": ["*.md"]
}
}
```

### `embeddingsProvider`

@@ -118,7 +119,10 @@ Embeddings model settings - the model used for @Codebase and @docs.
- `apiBase`: Base URL for API requests.
- `requestOptions`: Additional HTTP request settings specific to the embeddings provider.
- `maxChunkSize`: Maximum tokens per document chunk. Minimum is 128 tokens.
- `maxBatchSize`: Maximum number of chunks per request. Minimum is 1 chunk.
- `maxBatchSize`: Maximum number of chunks per request. Minimum is 1 chunk.
- `embeddingPrefixes`: Task-specific prefixes prepended to text before embedding.
  - `chunk`: The prefix string for chunk (document) embeddings.
  - `query`: The prefix string for query embeddings.

(AWS ONLY)

@@ -41,7 +41,11 @@ slug: ../ollama
{
"embeddingsProvider": {
"provider": "ollama",
"model": "nomic-embed-text"
"model": "nomic-embed-text",
"embeddingPrefixes" : {
"chunk": "search_document: ",
"query": "search_query: "
}
}
}
```
16 changes: 16 additions & 0 deletions extensions/vscode/config_schema.json
@@ -2805,6 +2805,22 @@
"title": "Profile",
"description": "The AWS security profile to use",
"type": "string"
},
"embeddingPrefixes" : {
"title": "Embedding Prefixes",
"description" : "Prefixes for the embedding model tasks",
"type": "object",
"properties": {
"chunk": {
"description": "Prefix to be used before chunks of code when generating embeddings.",
"type": "string"
},
"query": {
"description": "Prefix to be used before queries when generating embeddings.",
"type": "string"
}
},
"required": []
}
},
"required": ["provider"],
2 changes: 1 addition & 1 deletion packages/config-yaml/package.json
@@ -1,6 +1,6 @@
{
"name": "@continuedev/config-yaml",
"version": "1.0.63",
"version": "1.0.64",
"description": "",
"main": "dist/index.js",
"types": "dist/index.d.ts",
10 changes: 10 additions & 0 deletions packages/config-yaml/src/schemas/models.ts
@@ -50,9 +50,19 @@ export const completionOptionsSchema = z.object({
});
export type CompletionOptions = z.infer<typeof completionOptionsSchema>;

export const embeddingTasksSchema = z.union([
z.literal("chunk"),
z.literal("query")
]);
export type EmbeddingTasks = z.infer<typeof embeddingTasksSchema>;

export const embeddingPrefixesSchema = z.record(embeddingTasksSchema, z.string());
export type EmbeddingPrefixes = z.infer<typeof embeddingPrefixesSchema>;

export const embedOptionsSchema = z.object({
maxChunkSize: z.number().optional(),
maxBatchSize: z.number().optional(),
embeddingPrefixes: embeddingPrefixesSchema.optional(),
});
export type EmbedOptions = z.infer<typeof embedOptionsSchema>;
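
A minimal usage sketch of the new YAML schema (not part of the diff; the import path and values are assumptions):

```typescript
// Illustrative only: validating embed options, including task prefixes, with zod.
import { embedOptionsSchema } from "./models"; // hypothetical relative path

const opts = embedOptionsSchema.parse({
  maxChunkSize: 512,
  embeddingPrefixes: {
    chunk: "search_document: ",
    query: "search_query: ",
  },
});

// The prefixes survive parsing and can be forwarded to LLMOptions.embeddingPrefixes.
console.log(opts.embeddingPrefixes?.chunk); // "search_document: "
```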

Expand Down
Loading