Skip to content

Commit 8a278ff

Browse files
Rob Leidle
and
Rob Leidle
authored
Enable async encoding for the GPTEncoder (continuedev#1946)
This will improve the performance of indexing when using a GPT model. It also extends the AsyncEncoder interface with an async decode method. Note: I was only able to test this change with the test I added, as I do not have access to a GPT model. Co-authored-by: Rob Leidle <[email protected]>
1 parent 88c72c3 commit 8a278ff

File tree

5 files changed

+87
-9
lines changed

5 files changed

+87
-9
lines changed

core/llm/asyncEncoder.ts

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,23 @@ import * as path from "path";
55

66
export interface AsyncEncoder {
77
encode(text: string): Promise<number[]>;
8-
decode(tokens: number[]): string;
8+
decode(tokens: number[]): Promise<string>;
9+
close(): Promise<void>;
910
}
1011

1112
export class LlamaAsyncEncoder implements AsyncEncoder {
1213
private workerPool: workerpool.Pool;
1314

1415
constructor() {
15-
this.workerPool = workerpool.pool(path.join(__dirname, "/llamaTokenizerWorkerPool.mjs"));
16+
this.workerPool = workerpool.pool(workerCodeFilePath("llamaTokenizerWorkerPool.mjs"));
1617
}
1718

1819
async encode(text: string): Promise<number[]> {
1920
return this.workerPool.exec("encode", [text]);
2021
}
2122

22-
decode(tokens: number[]): string {
23-
return llamaTokenizer.decode(tokens);
23+
async decode(tokens: number[]): Promise<string> {
24+
return this.workerPool.exec("decode", [tokens]);
2425
}
2526

2627
// TODO: this should be called somewhere before exit or potentially with a shutdown hook
@@ -31,17 +32,30 @@ export class LlamaAsyncEncoder implements AsyncEncoder {
3132

3233
// this class does not yet do anything asynchronous
3334
export class GPTAsyncEncoder implements AsyncEncoder {
34-
private tiktokenEncoding: Tiktoken;
35+
private workerPool: workerpool.Pool;
3536

3637
constructor() {
37-
this.tiktokenEncoding = _encodingForModel("gpt-4");
38+
this.workerPool = workerpool.pool(workerCodeFilePath("tiktokenWorkerPool.mjs"));
3839
}
3940

4041
async encode(text: string): Promise<number[]> {
41-
return this.tiktokenEncoding.encode(text, "all", []);
42+
return this.workerPool.exec("encode", [text]);
43+
}
44+
45+
async decode(tokens: number[]): Promise<string> {
46+
return this.workerPool.exec("decode", [tokens]);
47+
}
48+
49+
// TODO: this should be called somewhere before exit or potentially with a shutdown hook
50+
public async close(): Promise<void> {
51+
await this.workerPool.terminate();
4252
}
53+
}
4354

44-
decode(tokens: number[]): string {
45-
return this.tiktokenEncoding.decode(tokens);
55+
function workerCodeFilePath(workerFileName: string): string {
56+
if (process.env.NODE_ENV === "test") {
57+
// `cross-env` seems to make it so __dirname is the root of the project and not the directory containing this file
58+
return path.join(__dirname, "llm", workerFileName);
4659
}
60+
return path.join(__dirname, workerFileName);
4761
}

core/llm/llamaTokenizerWorkerPool.mjs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ function encode(segment) {
55
return llamaTokenizer.encode(segment);
66
}
77

8+
function decode(tokens) {
9+
return llamaTokenizer.decode(tokens);
10+
}
11+
812
workerpool.worker({
13+
decode,
914
encode,
1015
});

core/llm/tiktokenWorkerPool.mjs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import workerpool from "workerpool";
2+
import { encodingForModel as _encodingForModel } from "js-tiktoken";
3+
4+
const tiktokenEncoding = _encodingForModel("gpt-4");
5+
6+
function encode(text) {
7+
return tiktokenEncoding.encode(text, "all", []);
8+
}
9+
10+
function decode(tokens) {
11+
return tiktokenEncoding.decode(tokens);
12+
}
13+
14+
workerpool.worker({
15+
decode,
16+
encode,
17+
});

core/test/llm/asyncEncoder.test.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { GPTAsyncEncoder, LlamaAsyncEncoder } from "../../llm/asyncEncoder";
2+
3+
describe("llama encoder", () => {
4+
var tokenizer: LlamaAsyncEncoder;
5+
6+
beforeAll(() => {
7+
tokenizer = new LlamaAsyncEncoder();
8+
});
9+
10+
afterAll(() => {
11+
tokenizer.close();
12+
});
13+
14+
test("hello world", async () => {
15+
const input = "the quick brown fox jumped over the lazy dog";
16+
const output = await tokenizer.encode(input);
17+
expect(output).toEqual([1, 278, 4996, 17354, 1701, 29916, 12500, 287, 975, 278, 17366, 11203]);
18+
const decoded = await tokenizer.decode(output);
19+
expect(decoded).toBe(input);
20+
});
21+
});
22+
23+
describe("tiktoken encoder", () => {
24+
var tokenizer: GPTAsyncEncoder;
25+
26+
beforeAll(() => {
27+
tokenizer = new GPTAsyncEncoder();
28+
});
29+
30+
afterAll(() => {
31+
tokenizer.close();
32+
});
33+
34+
test("hello world", async () => {
35+
const input = "the quick brown fox jumped over the lazy dog";
36+
const output = await tokenizer.encode(input);
37+
expect(output).toEqual([1820, 4062, 14198, 39935, 27096, 927, 279, 16053, 5679]);
38+
const decoded = await tokenizer.decode(output);
39+
expect(decoded).toBe(input);
40+
});
41+
});

extensions/vscode/scripts/prepackage.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ const exe = os === "win32" ? ".exe" : "";
245245
"../../../core/vendor/tree-sitter.wasm",
246246
"../../../core/llm/llamaTokenizerWorkerPool.mjs",
247247
"../../../core/llm/llamaTokenizer.mjs",
248+
"../../../core/llm/tiktokenWorkerPool.mjs",
248249
];
249250
for (const f of filesToCopy) {
250251
fs.copyFileSync(path.join(__dirname, f), path.join(__dirname, "..", "out", path.basename(f)));

0 commit comments

Comments (0)