import asyncio
import datetime
import json
from enum import Enum
from time import time
from typing import Any, Callable, Optional, Tuple

import httpx

from core.config import LLMConfig, LLMProvider
from core.llm.convo import Convo
from core.llm.request_log import LLMRequestLog, LLMRequestStatus
from core.log import get_logger

log = get_logger(__name__)


class LLMError(str, Enum):
    KEY_EXPIRED = "key_expired"
    RATE_LIMITED = "rate_limited"
    GENERIC_API_ERROR = "generic_api_error"


class APIError(Exception):
    def __init__(self, message: str):
        self.message = message


class BaseLLMClient:
    """
    Base asynchronous streaming client for language models.

    Example usage:

    >>> async def stream_handler(content: str):
    ...     print(content)
    ...
    >>> def parser(content: str) -> dict:
    ...     return json.loads(content)
    ...
    >>> client_class = BaseLLMClient.for_provider(provider)
    >>> client = client_class(config, stream_handler=stream_handler)
    >>> response, request_log = await client(convo, parser=parser)
    """
    provider: LLMProvider

    def __init__(
        self,
        config: LLMConfig,
        *,
        stream_handler: Optional[Callable] = None,
        error_handler: Optional[Callable] = None,
    ):
        """
        Initialize the client with the given configuration.

        :param config: Configuration for the client.
        :param stream_handler: Optional async handler for streamed response chunks.
        :param error_handler: Optional async handler called on recoverable errors
            (rate limits, expired keys, repeated failures); it may be asked whether
            the request should be retried.
        """
        self.config = config
        self.stream_handler = stream_handler
        self.error_handler = error_handler
        self._init_client()

    def _init_client(self):
        raise NotImplementedError()
    async def _make_request(
        self,
        convo: Convo,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> tuple[str, int, int]:
        """
        Call the LLM with the given conversation.

        Low-level method that streams the response chunks.
        Use `__call__` instead of this method.

        :param convo: Conversation to send to the LLM.
        :param temperature: Optional temperature to use for this request.
        :param json_mode: If True, the response is expected to be JSON.
        :return: Tuple containing the full response content, number of input tokens, and number of output tokens.
        """
        raise NotImplementedError()
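    # Concrete provider clients implement _make_request() by calling the provider
    # SDK's streaming API and passing each response chunk to self.stream_handler;
    # an illustrative (non-streaming) sketch is included at the end of this module.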
    async def _adapt_messages(self, convo: Convo) -> list[dict[str, str]]:
        """
        Adapt the conversation messages to the format expected by the LLM.

        Claude only recognizes the "user" and "assistant" roles, so "system"
        messages are sent as "user" and consecutive messages with the same
        role are merged into one.

        :param convo: Conversation to adapt.
        :return: Adapted conversation messages.
        """
        messages = []
        for msg in convo.messages:
            if msg.role == "function":
                raise ValueError("Anthropic Claude doesn't support function calling")

            role = "user" if msg.role in ["user", "system"] else "assistant"
            if messages and messages[-1]["role"] == role:
                messages[-1]["content"] += "\n\n" + msg.content
            else:
                messages.append(
                    {
                        "role": role,
                        "content": msg.content,
                    }
                )
        return messages
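    # For illustration (hypothetical conversation), _adapt_messages() would turn
    #   [system] "Be brief."  [user] "Hi"  [assistant] "Hello"
    # into
    #   [{"role": "user", "content": "Be brief.\n\nHi"},
    #    {"role": "assistant", "content": "Hello"}]
    # because "system" is mapped to "user" and consecutive same-role messages are merged.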
    async def __call__(
        self,
        convo: Convo,
        *,
        temperature: Optional[float] = None,
        parser: Optional[Callable] = None,
        max_retries: int = 3,
        json_mode: bool = False,
    ) -> Tuple[Any, LLMRequestLog]:
        """
        Invoke the LLM with the given conversation.

        Stream handler, if provided, should be an async function
        that takes a single argument, the response content (str).
        It will be called for each response chunk.

        Parser, if provided, should be a function that takes the
        response content (str) and returns the parsed response.
        On parse error, the parser should raise a ValueError with
        a descriptive error message that will be sent back to the LLM
        to retry, up to max_retries.

        :param convo: Conversation to send to the LLM.
        :param temperature: Optional temperature override (defaults to the configured temperature).
        :param parser: Optional parser for the response.
        :param max_retries: Maximum number of retries for parsing the response.
        :param json_mode: If True, the response is expected to be JSON.
        :return: Tuple of the (parsed) response and request log entry.
        """
        import anthropic
        import groq
        import openai

        if temperature is None:
            temperature = self.config.temperature

        convo = convo.fork()
        request_log = LLMRequestLog(
            provider=self.provider,
            model=self.config.model,
            temperature=temperature,
            prompts=convo.prompt_log,
        )

        prompt_length_kb = len(json.dumps(convo.messages).encode("utf-8")) / 1024
        log.debug(
            f"Calling {self.provider.value} model {self.config.model} (temp={temperature}), prompt length: {prompt_length_kb:.1f} KB"
        )

        t0 = time()
        remaining_retries = max_retries
        while True:
            if remaining_retries == 0:
                # We've run out of auto-retries
                if request_log.error:
                    last_error_msg = f"Error connecting to the LLM: {request_log.error}"
                else:
                    last_error_msg = "Error parsing LLM response"

                # If we can, ask the user if they want to keep retrying
                if self.error_handler:
                    should_retry = await self.error_handler(LLMError.GENERIC_API_ERROR, message=last_error_msg)
                    if should_retry:
                        remaining_retries = max_retries
                        continue

                # They don't want to retry (or we can't ask them), raise the last error and stop Pythagora
                raise APIError(last_error_msg)

            remaining_retries -= 1

            request_log.messages = convo.messages[:]
            request_log.response = None
            request_log.status = LLMRequestStatus.SUCCESS
            request_log.error = None
            response = None
            try:
                response, prompt_tokens, completion_tokens = await self._make_request(
                    convo,
                    temperature=temperature,
                    json_mode=json_mode,
                )
            except (openai.APIConnectionError, anthropic.APIConnectionError, groq.APIConnectionError) as err:
                log.warning(f"API connection error: {err}", exc_info=True)
                request_log.error = f"API connection error: {err}"
                request_log.status = LLMRequestStatus.ERROR
                continue
            except httpx.ReadTimeout as err:
                log.warning(f"Read timeout (set to {self.config.read_timeout}s): {err}", exc_info=True)
                request_log.error = f"Read timeout: {err}"
                request_log.status = LLMRequestStatus.ERROR
                continue
            except httpx.ReadError as err:
                log.warning(f"Read error: {err}", exc_info=True)
                request_log.error = f"Read error: {err}"
                request_log.status = LLMRequestStatus.ERROR
                continue
            except (openai.RateLimitError, anthropic.RateLimitError, groq.RateLimitError) as err:
                log.warning(f"Rate limit error: {err}", exc_info=True)
                request_log.error = f"Rate limit error: {err}"
                request_log.status = LLMRequestStatus.ERROR
                wait_time = self.rate_limit_sleep(err)
                if wait_time:
                    message = f"We've hit {self.config.provider.value} rate limit. Sleeping for {wait_time.seconds} seconds..."
                    if self.error_handler:
                        await self.error_handler(LLMError.RATE_LIMITED, message)
                    await asyncio.sleep(wait_time.seconds)
                    continue
                else:
                    # RateLimitError that shouldn't be retried, eg. insufficient funds
                    err_msg = err.response.json().get("error", {}).get("message", "Rate limiting error.")
                    raise APIError(err_msg) from err
            except (openai.NotFoundError, anthropic.NotFoundError, groq.NotFoundError) as err:
                err_msg = err.response.json().get("error", {}).get("message", f"Model not found: {self.config.model}")
                raise APIError(err_msg) from err
            except (openai.AuthenticationError, anthropic.AuthenticationError, groq.AuthenticationError) as err:
                log.warning(f"Key expired: {err}", exc_info=True)
                err_msg = err.response.json().get("error", {}).get("message", "Incorrect API key")
                if "[BricksLLM]" in err_msg:
                    # We only want to show the key expired message if it's from Bricks
                    if self.error_handler:
                        should_retry = await self.error_handler(LLMError.KEY_EXPIRED)
                        if should_retry:
                            continue
                raise APIError(err_msg) from err
            except (openai.APIStatusError, anthropic.APIStatusError, groq.APIStatusError) as err:
                # Token limit exceeded (handled in the original gpt-pilot as
                # TokenLimitError) is thrown as 400 (OpenAI, Anthropic) or 413 (Groq).
                # All providers throw an exception that is caught here.
                # OpenAI and Groq return a `code` field in the error JSON that lets
                # us confirm that we've breached the token limit, but Anthropic doesn't,
                # so we can't be certain that's the problem in Anthropic's case.
                # Here we try to detect that and tell the user what happened.
                log.info(f"API status error: {err}")
                try:
                    if hasattr(err, "response"):
                        if err.response.headers.get("Content-Type", "").startswith("application/json"):
                            err_code = err.response.json().get("error", {}).get("code", "")
                        else:
                            err_code = str(err.response.text)
                    else:
                        # Exceptions aren't JSON-serializable, so fall back to the string representation
                        err_code = str(err)
                except Exception as e:
                    err_code = f"Error parsing response: {str(e)}"
                if err_code in ("request_too_large", "context_length_exceeded", "string_above_max_length"):
                    # Handle OpenAI and Groq token limit exceeded errors.
                    # OpenAI will return `string_above_max_length` for prompts longer than 1M characters.
                    message = "".join(
                        [
                            "We sent too large a request to the LLM, resulting in an error. ",
                            "This is usually caused by including framework files in an LLM request. ",
                            "Here's how you can get GPT Pilot to ignore those extra files: ",
                            "https://bit.ly/faq-token-limit-error",
                        ]
                    )
                    raise APIError(message) from err

                log.warning(f"API error: {err}", exc_info=True)
                request_log.error = f"API error: {err}"
                request_log.status = LLMRequestStatus.ERROR
                continue
            except (openai.APIError, anthropic.APIError, groq.APIError) as err:
                # Generic LLM API error.
                # Make sure this handler is last in the chain, as some of the above
                # errors inherit from these `APIError` classes.
                log.warning(f"LLM API error: {err}", exc_info=True)
                request_log.error = f"LLM had an error processing our request: {err}"
                request_log.status = LLMRequestStatus.ERROR
                continue
            request_log.response = response
            request_log.prompt_tokens += prompt_tokens
            request_log.completion_tokens += completion_tokens

            if parser:
                try:
                    response = parser(response)
                    break
                except ValueError as err:
                    request_log.error = f"Error parsing response: {err}"
                    request_log.status = LLMRequestStatus.ERROR
                    log.debug(f"Error parsing LLM response: {err}, asking LLM to retry", exc_info=True)
                    convo.assistant(response)
                    convo.user(f"Error parsing response: {err}. Please output your response EXACTLY as requested.")
                    continue
            else:
                break

        t1 = time()
        request_log.duration = t1 - t0

        log.debug(
            f"Total {self.provider.value} response time {request_log.duration:.2f}s, {request_log.prompt_tokens} prompt tokens, {request_log.completion_tokens} completion tokens used"
        )

        return response, request_log
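    # Example of the parser contract described above (hypothetical helper): raise
    # ValueError with a descriptive message so __call__ can feed it back to the LLM.
    #
    #   def require_json_list(content: str) -> list:
    #       data = json.loads(content)  # JSONDecodeError is a subclass of ValueError
    #       if not isinstance(data, list):
    #           raise ValueError("Expected a JSON list")
    #       return data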
    async def api_check(self) -> bool:
        """
        Perform an LLM API check.

        :return: True if the check was successful, False otherwise.
        """
        convo = Convo()
        msg = "This is a connection test. If you can see this, please respond only with 'START' and nothing else."
        convo.user(msg)
        resp, _log = await self(convo)
        return bool(resp)
    @staticmethod
    def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
        """
        Return the LLM client class for the specified provider.

        :param provider: Provider to return the client for.
        :return: Client class for the specified provider.
        """
        from .aiml_client import AIMLClient
        from .anthropic_client import AnthropicClient
        from .azure_client import AzureClient
        from .groq_client import GroqClient
        from .openai_client import OpenAIClient

        if provider == LLMProvider.OPENAI:
            return OpenAIClient
        elif provider == LLMProvider.AIML:
            return AIMLClient
        elif provider == LLMProvider.ANTHROPIC:
            return AnthropicClient
        elif provider == LLMProvider.GROQ:
            return GroqClient
        elif provider == LLMProvider.AZURE:
            return AzureClient
        else:
            raise ValueError(f"Unsupported LLM provider: {provider.value}")
    def rate_limit_sleep(self, err: Exception) -> Optional[datetime.timedelta]:
        """
        Return how long we need to sleep because of rate limiting.

        The wait time is computed from the rate-limiting headers that each LLM
        provider returns; for details, check the implementation for the specific
        provider. If there are no rate-limiting headers, we assume the request
        should not be retried and return None (this will be the case for
        insufficient quota/funds in the account).

        :param err: RateLimitError that was raised by the LLM client.
        :return: Optional timedelta to wait before trying again.
        """
        raise NotImplementedError()

__all__ = ["BaseLLMClient"]
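

# Illustrative sketch (not part of the original module): a minimal concrete client
# showing how the hooks above fit together. Anything not guaranteed by this module
# is an assumption made only for the example: an OpenAI-compatible `/chat/completions`
# endpoint, `base_url`/`api_key` fields on LLMConfig, a numeric `Retry-After` header,
# and `convo.messages` being JSON-serializable. Real provider clients (openai_client.py,
# anthropic_client.py, ...) stream their responses and raise the provider SDK errors
# that __call__ knows how to handle.
class _ExampleHTTPClient(BaseLLMClient):
    provider = LLMProvider.OPENAI

    def _init_client(self):
        # Reuse the configured read timeout for the underlying HTTP client.
        self.client = httpx.AsyncClient(
            base_url=self.config.base_url,  # assumed field
            headers={"Authorization": f"Bearer {self.config.api_key}"},  # assumed field
            timeout=self.config.read_timeout,
        )

    async def _make_request(
        self,
        convo: Convo,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> tuple[str, int, int]:
        payload = {
            "model": self.config.model,
            "messages": convo.messages,
            "temperature": temperature,
        }
        if json_mode:
            payload["response_format"] = {"type": "json_object"}

        resp = await self.client.post("/chat/completions", json=payload)
        resp.raise_for_status()
        data = resp.json()

        content = data["choices"][0]["message"]["content"]
        if self.stream_handler:
            # No real streaming in this sketch: deliver the whole response at once.
            await self.stream_handler(content)

        usage = data.get("usage", {})
        return content, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)

    def rate_limit_sleep(self, err: Exception) -> Optional[datetime.timedelta]:
        # Assume a numeric Retry-After header (seconds); returning None tells
        # __call__ the error should not be retried.
        response = getattr(err, "response", None)
        retry_after = response.headers.get("Retry-After") if response is not None else None
        return datetime.timedelta(seconds=int(retry_after)) if retry_after else None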