diff --git a/docs/agents.md b/docs/agents.md index a89b87dd8..d3f8c070d 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -153,6 +153,7 @@ class NeverResultType(TypedDict): agent = Agent( 'claude-3-5-sonnet-latest', + retries=3, result_type=NeverResultType, system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.', ) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 0aff3de3a..fc538bd4e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -562,6 +562,8 @@ async def on_complete(): parts = await self._process_function_tools( tool_calls, result_tool_name, run_context, result_schema ) + if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + self._incr_result_retry(run_context) if parts: messages.append(_messages.ModelRequest(parts)) run_span.set_attribute('all_messages', messages) @@ -1147,7 +1149,6 @@ async def _handle_structured_response( result_data = result_tool.validate(call) result_data = await self._validate_result(result_data, run_context, call) except _result.ToolRetryError as e: - self._incr_result_retry(run_context) parts.append(e.tool_retry) else: final_result = _MarkFinalResult(result_data, call.tool_name) @@ -1157,6 +1158,9 @@ async def _handle_structured_response( tool_calls, final_result and final_result.tool_name, run_context, result_schema ) + if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + self._incr_result_retry(run_context) + return final_result, parts async def _process_function_tools( @@ -1210,7 +1214,7 @@ async def _process_function_tools( ) ) else: - parts.append(self._unknown_tool(call.tool_name, run_context, result_schema)) + parts.append(self._unknown_tool(call.tool_name, result_schema)) # Run all tool tasks in parallel if tasks: @@ -1257,7 +1261,7 @@ async def _handle_streamed_response( if tool := self._function_tools.get(p.tool_name): tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) else: - parts.append(self._unknown_tool(p.tool_name, run_context, result_schema)) + parts.append(self._unknown_tool(p.tool_name, result_schema)) if received_text and not tasks and not parts: # Can only get here if self._allow_text_result returns `False` for the provided result_schema @@ -1270,6 +1274,10 @@ async def _handle_streamed_response( with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) parts.extend(task_results) + + if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + self._incr_result_retry(run_context) + return model_response, parts async def _validate_result( @@ -1307,10 +1315,8 @@ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_message def _unknown_tool( self, tool_name: str, - run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None, ) -> _messages.RetryPromptPart: - self._incr_result_retry(run_context) names = list(self._function_tools.keys()) if result_schema: names.extend(result_schema.tool_names())