diff --git a/requirements.txt b/requirements.txt index c7c0b55..17dc297 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ PyYAML==6.0 beautifulsoup4==4.12.2 numpy==1.24.2 openai==0.27.8 +litellm==0.1.226 python-dotenv==1.0.0 pytz==2023.3 sendgrid==6.10.0 diff --git a/src/utils.py b/src/utils.py index c128702..37fc69b 100644 --- a/src/utils.py +++ b/src/utils.py @@ -9,6 +9,7 @@ from typing import Optional, Sequence, Union import openai +import litellm import tqdm from openai import openai_object import copy @@ -70,6 +71,7 @@ def openai_completion( - a list of objects of the above types (if decoding_args.n > 1) """ is_chat_model = "gpt-3.5" in model_name or "gpt-4" in model_name + is_litellm_model = model_name in litellm.model_list is_single_prompt = isinstance(prompts, (str, dict)) if is_single_prompt: prompts = [prompts] @@ -113,6 +115,14 @@ def openai_completion( ], **shared_kwargs ) + elif is_litellm_model: + completion_batch = litellm.completion( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt_batch[0]} + ], + **shared_kwargs + ) else: completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)