
Commit 1bd5aa3

feat(sdk-js): choosing model in generate
Parent: a717dbd

File tree — 4 files changed: +141 −44 lines

  sdk/embedbase-js/src/types.ts
  sdk/embedbase-py/embedbase_client/async_client.py
  sdk/embedbase-py/embedbase_client/sync_client.py
  sdk/embedbase-py/tests/test_client.py

sdk/embedbase-js/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ type Chat = {
 export interface GenerateOptions {
   history: Chat[]
   url?: string
+  model?: 'gpt-3.5-turbo-16k' | 'falcon'
 }
 
 export interface RangeOptions {

sdk/embedbase-py/embedbase_client/async_client.py

Lines changed: 46 additions & 22 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import asyncio
 import itertools
@@ -579,12 +579,13 @@ def batch_chunks(l, n):
 
     async def create_max_context(
         self,
-        dataset: str,
+        dataset: Union[str, List[str]],
         query: str,
-        max_tokens: int,
+        max_tokens: Union[int, List[int]],
     ) -> str:
         """
-        Create a context from a query by searching for similar documents and concatenating them up to the specified max tokens.
+        Create a context from a query by searching for similar documents and
+        concatenating them up to the specified max tokens.
 
         Args:
             dataset: The name of the dataset to search.
@@ -595,32 +596,55 @@ async def create_max_context(
             A string containing the context.
 
         Example usage:
-            context = create_max_context("Python is a programming language.", max_tokens=30)
+            context = await create_max_context("programming", "Python is a programming language.", 30)
             print(context)
             # Python is a programming language.
             # Python is a high-level, general-purpose programming language.
             # Python is interpreted, dynamically typed and garbage-collected.
             # Python is designed to be highly extensible.
             # Python is a multi-paradig
+            # or
+            context = await create_max_context(["programming", "science"], "Python lives planet earth.", [3, 30])
+            print(context)
+            # Pyt
+            # The earth orbits the sun.
+            # The earth is the third planet from the sun.
+            # The earth is the only planet known to support life.
+            # The earth formed approximately 4.5 billion years ago.
+            # The earth's gravity interacts with other objects in space, especially the sun and the moon.
         """
 
-        # try to build a context until it's big enough by incrementing top_k
-        top_k = 100
-        context = await self.create_context(dataset, query, top_k)
-        merged_context, size = merge_and_return_tokens(context, max_tokens)
-
-        tries = 0
-        max_tries = 3
-        while size < max_tokens and tries < max_tries:
-            top_k *= 3
-            context = await self.create_context(dataset, query, top_k)
+        async def create_context_for_dataset(d, max_tokens):
+            top_k = 100
+            context = await self.create_context(d, query, top_k)
             merged_context, size = merge_and_return_tokens(context, max_tokens)
-            tries += 1
 
-        if size < max_tokens:
-            # warn the user that the context is smaller than the max tokens
-            print(
-                f"Warning: context is smaller than the max tokens ({size} < {max_tokens})"
-            )
+            tries = 0
+            max_tries = 3
+            while size < max_tokens and tries < max_tries:
+                top_k *= 3
+                context = await self.create_context(d, query, top_k)
+                merged_context, size = merge_and_return_tokens(context, max_tokens)
+                tries += 1
+
+            if size < max_tokens:
+                print(
+                    f"Warning: context for dataset '{d}' is smaller than the max tokens ({size} < {max_tokens})"
+                )
+            return merged_context
+
+        if not isinstance(dataset, list):
+            dataset = [dataset]
+
+        if not isinstance(max_tokens, list):
+            max_tokens = [max_tokens for _ in range(len(dataset))]
+
+        if len(dataset) != len(max_tokens):
+            raise ValueError("The number of datasets and max_tokens should be equal.")
+
+        contexts = []
+        for ds, mt in zip(dataset, max_tokens):
+            context = await create_context_for_dataset(ds, mt)
+            contexts.append(context)
 
-        return merged_context
+        return "\n\n".join(contexts)

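Both clients defer token accounting to merge_and_return_tokens, which this commit does not touch and the diff does not show. The sketch below is one plausible reading of that helper, not the actual implementation: it assumes create_context returns a plain list of document strings, that documents are joined with a newline, and that the last document is truncated token by token so the budget is never exceeded (consistent with the truncated "Pyt" output in the docstring and with the tests, which measure budgets using tiktoken's cl100k_base encoding).

from typing import List, Tuple

from tiktoken import get_encoding

tokenizer = get_encoding("cl100k_base")


def merge_and_return_tokens(documents: List[str], max_tokens: int) -> Tuple[str, int]:
    # Concatenate documents until the budget is spent; truncate the last
    # one instead of overshooting, and report the merged size in tokens.
    merged = ""
    for doc in documents:
        candidate = merged + ("\n" if merged else "") + doc
        tokens = tokenizer.encode(candidate)
        if len(tokens) > max_tokens:
            merged = tokenizer.decode(tokens[:max_tokens])
            break
        merged = candidate
    return merged, len(tokenizer.encode(merged))

Under this reading, size < max_tokens after the first call means the initial 100 search results did not fill the budget, which is what triggers the tripling of top_k in create_context_for_dataset: with max_tries = 3, top_k grows at most 100 → 300 → 900 → 2700.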
sdk/embedbase-py/embedbase_client/sync_client.py

Lines changed: 46 additions & 22 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Dict, Generator, List, Optional, Union
 
 import itertools
 import json
@@ -600,12 +600,13 @@ def add_batch(batch):
 
     def create_max_context(
         self,
-        dataset: str,
+        dataset: Union[str, List[str]],
         query: str,
-        max_tokens: int,
+        max_tokens: Union[int, List[int]],
    ) -> str:
         """
-        Create a context from a query by searching for similar documents and concatenating them up to the specified max tokens.
+        Create a context from a query by searching for similar documents and
+        concatenating them up to the specified max tokens.
 
         Args:
             dataset: The name of the dataset to search.
@@ -616,32 +617,55 @@ def create_max_context(
             A string containing the context.
 
         Example usage:
-            context = embedbase.create_max_context("my_dataset", "What is Python?", max_tokens=30)
+            context = create_max_context("programming", "Python is a programming language.", 30)
             print(context)
             # Python is a programming language.
             # Python is a high-level, general-purpose programming language.
             # Python is interpreted, dynamically typed and garbage-collected.
             # Python is designed to be highly extensible.
-            # Python is a multi-paradig...
+            # Python is a multi-paradig
+            # or
+            context = create_max_context(["programming", "science"], "Python lives planet earth.", [3, 30])
+            print(context)
+            # Pyt
+            # The earth orbits the sun.
+            # The earth is the third planet from the sun.
+            # The earth is the only planet known to support life.
+            # The earth formed approximately 4.5 billion years ago.
+            # The earth's gravity interacts with other objects in space, especially the sun and the moon.
         """
 
-        # try to build a context until it's big enough by incrementing top_k
-        top_k = 100
-        context = self.create_context(dataset, query, top_k)
-        merged_context, size = merge_and_return_tokens(context, max_tokens)
-
-        tries = 0
-        max_tries = 3
-        while size < max_tokens and tries < max_tries:
-            top_k *= 3
+        def create_context_for_dataset(dataset, max_tokens):
+            top_k = 100
             context = self.create_context(dataset, query, top_k)
             merged_context, size = merge_and_return_tokens(context, max_tokens)
-            tries += 1
 
-        if size < max_tokens:
-            # warn the user that the context is smaller than the max tokens
-            print(
-                f"Warning: context is smaller than the max tokens ({size} < {max_tokens})"
-            )
+            tries = 0
+            max_tries = 3
+            while size < max_tokens and tries < max_tries:
+                top_k *= 3
+                context = self.create_context(dataset, query, top_k)
+                merged_context, size = merge_and_return_tokens(context, max_tokens)
+                tries += 1
+
+            if size < max_tokens:
+                print(
+                    f"Warning: context for dataset '{dataset}' is smaller than the max tokens ({size} < {max_tokens})"
+                )
+            return merged_context
+
+        if not isinstance(dataset, list):
+            dataset = [dataset]
+
+        if not isinstance(max_tokens, list):
+            max_tokens = [max_tokens for _ in range(len(dataset))]
+
+        if len(dataset) != len(max_tokens):
+            raise ValueError("The number of datasets and max_tokens should be equal.")
+
+        contexts = []
+        for ds, mt in zip(dataset, max_tokens):
+            context = create_context_for_dataset(ds, mt)
+            contexts.append(context)
 
-        return merged_context
+        return "\n\n".join(contexts)

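Put together with the test below, the new multi-dataset signature can be exercised end to end roughly as follows. This is a sketch under assumptions: the EmbedbaseClient constructor arguments and the sample documents are illustrative only; of the calls shown, only dataset(...), batch_add(...), and create_max_context(...) appear in this commit's diff and tests.

from embedbase_client import EmbedbaseClient

# Hypothetical endpoint and key, for illustration only.
client = EmbedbaseClient("https://api.embedbase.xyz", "<api-key>")

client.dataset("programming").batch_add([{"data": "Python is a programming language."}])
client.dataset("animals").batch_add([{"data": "Python is a type of snake."}])

# One token budget per dataset; the per-dataset contexts come back
# joined by a blank line ("\n\n").
context = client.create_max_context(["programming", "animals"], "What is Python?", [20, 25])
print(context)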
sdk/embedbase-py/tests/test_client.py

Lines changed: 48 additions & 0 deletions
@@ -203,3 +203,51 @@ async def test_create_max_context_async():
 
     assert isinstance(context, str)
     assert len(tokenizer.encode(context)) <= max_tokens
+
+
+@pytest.mark.asyncio
+async def test_create_max_context_multiple_datasets_async():
+    query = "What is Python?"
+    dataset1 = "programming"
+    dataset2 = "animals"
+    max_tokens1 = 20
+    max_tokens2 = 25
+    await client.dataset(dataset1).clear()
+    await client.dataset(dataset2).clear()
+    programming_documents = [
+        "Python is a programming language.",
+        "Java is another popular programming language.",
+        "JavaScript is widely used for web development.",
+        "C++ is commonly used for system programming.",
+        "Ruby is known for its simplicity and readability.",
+        "Go is a statically typed language developed by Google.",
+        "Rust is a systems programming language that focuses on safety and performance.",
+        "TypeScript is a superset of JavaScript that adds static typing.",
+        "PHP is a server-side scripting language used for web development.",
+        "Swift is a modern programming language developed by Apple for iOS app development.",
+    ]
+    animal_documents = [
+        "Python is a type of snake.",
+        "Lions are known as the king of the jungle.",
+        "Elephants are the largest land animals.",
+        "Giraffes are known for their long necks.",
+        "Kangaroos are native to Australia.",
+        "Pandas are native to China and primarily eat bamboo.",
+        "Penguins live primarily in the Southern Hemisphere.",
+        "Tigers are carnivorous mammals found in Asia.",
+        "Whales are large marine mammals.",
+        "Zebras are part of the horse family and native to Africa.",
+    ]
+
+    await client.dataset(dataset1).batch_add([{"data": d} for d in programming_documents])
+    await client.dataset(dataset2).batch_add([{"data": d} for d in animal_documents])
+    context = await client.create_max_context(
+        [dataset1, dataset2], query, [max_tokens1, max_tokens2]
+    )
+    tokenizer = get_encoding("cl100k_base")
+
+    assert isinstance(context, str)
+    context_parts = context.split("\n\n")
+    assert len(context_parts) == 2
+    assert len(tokenizer.encode(context_parts[0])) <= max_tokens1
+    assert len(tokenizer.encode(context_parts[1])) <= max_tokens2