From 77f090abd0e732a446c692e12fc33e8d4ef4295a Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:29:58 -0700 Subject: [PATCH 1/2] Fix an issue with retry counting --- pydantic_ai_slim/pydantic_ai/agent.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 4a9b4493c..57be32bc1 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -560,6 +560,8 @@ async def on_complete(): parts = await self._process_function_tools( tool_calls, result_tool_name, run_context, result_schema ) + if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + self._incr_result_retry(run_context) if parts: messages.append(_messages.ModelRequest(parts)) run_span.set_attribute('all_messages', messages) @@ -1133,7 +1135,6 @@ async def _handle_structured_response( result_data = result_tool.validate(call) result_data = await self._validate_result(result_data, run_context, call) except _result.ToolRetryError as e: - self._incr_result_retry(run_context) parts.append(e.tool_retry) else: final_result = _MarkFinalResult(result_data, call.tool_name) @@ -1143,6 +1144,9 @@ async def _handle_structured_response( tool_calls, final_result and final_result.tool_name, run_context, result_schema ) + if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + self._incr_result_retry(run_context) + return final_result, parts async def _process_function_tools( @@ -1196,7 +1200,7 @@ async def _process_function_tools( ) ) else: - parts.append(self._unknown_tool(call.tool_name, run_context, result_schema)) + parts.append(self._unknown_tool(call.tool_name, result_schema)) # Run all tool tasks in parallel if tasks: @@ -1243,7 +1247,7 @@ async def _handle_streamed_response( if tool := self._function_tools.get(p.tool_name): tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) else: - parts.append(self._unknown_tool(p.tool_name, run_context, result_schema)) + parts.append(self._unknown_tool(p.tool_name, result_schema)) if received_text and not tasks and not parts: # Can only get here if self._allow_text_result returns `False` for the provided result_schema @@ -1256,6 +1260,10 @@ async def _handle_streamed_response( with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) parts.extend(task_results) + + if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + self._incr_result_retry(run_context) + return model_response, parts async def _validate_result( @@ -1293,10 +1301,8 @@ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages def _unknown_tool( self, tool_name: str, - run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None, ) -> _messages.RetryPromptPart: - self._incr_result_retry(run_context) names = list(self._function_tools.keys()) if result_schema: names.extend(result_schema.tool_names()) From 7280507eb6cde7180f2b5fde97e37df7ed3a178f Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:51:32 -0700 Subject: [PATCH 2/2] Be explicit about the number of allowed retries --- docs/agents.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/agents.md b/docs/agents.md index a89b87dd8..d3f8c070d 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -153,6 +153,7 @@ class NeverResultType(TypedDict): agent = Agent( 'claude-3-5-sonnet-latest', + retries=3, result_type=NeverResultType, system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.', )