diff --git a/pydantic_ai_slim/pydantic_ai/models/ollama.py b/pydantic_ai_slim/pydantic_ai/models/ollama.py
index 8a0f30340..2ac6d1065 100644
--- a/pydantic_ai_slim/pydantic_ai/models/ollama.py
+++ b/pydantic_ai_slim/pydantic_ai/models/ollama.py
@@ -1,10 +1,13 @@
 from __future__ import annotations as _annotations
 
+import json
 from dataclasses import dataclass
-from typing import Literal, Union
+from datetime import datetime, timezone
+from typing import Any, Literal, Union
 
 from httpx import AsyncClient as AsyncHTTPClient
 
+from ..messages import ModelResponse, ModelResponsePart, TextPart, ToolCallPart
 from ..tools import ToolDefinition
 from . import (
     AgentModel,
@@ -14,6 +17,9 @@
 
 try:
     from openai import AsyncOpenAI
+    from openai.types import chat
+
+    from .openai import OpenAIAgentModel
 except ImportError as e:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -54,6 +60,55 @@
 """
 
 
+class NestedJSONDecoder(json.JSONDecoder):
+    """Modification of the built-in JSON decoder to enable decoding of nested models provided by the Ollama API."""
+
+    def decode(self, s, _w=json.decoder.WHITESPACE.match):  # type: ignore
+        parsed = super().decode(s)
+        return self._decode_nested(parsed)
+
+    def _decode_nested(self, obj: Any) -> Any:
+        if isinstance(obj, dict):
+            return {key: self._decode_nested(value) for key, value in obj.items()}
+        elif isinstance(obj, list):
+            return [self._decode_nested(item) for item in obj]
+        elif isinstance(obj, str):
+            try:
+                return self._decode_nested(json.loads(obj))
+            except json.JSONDecodeError:
+                return obj
+        # non-container, non-string values (numbers, bools, None) pass through unchanged
+        return obj
+
+
+@dataclass
+class OllamaAgentModel(OpenAIAgentModel):
+    """Agent model for the Ollama API. Contains special handling for the escape characters in Ollama responses."""
+
+    @staticmethod
+    def _process_response(response: chat.ChatCompletion) -> ModelResponse:
+        """Override that deals with the extra escape characters in Ollama responses."""
+        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+        choice = response.choices[0]
+        items: list[ModelResponsePart] = []
+
+        if choice.message.content is not None:
+            items.append(TextPart(choice.message.content))
+
+        if choice.message.tool_calls is not None:
+            for c in choice.message.tool_calls:
+                try:
+                    items.append(
+                        ToolCallPart.from_raw_args(
+                            c.function.name, NestedJSONDecoder().decode(c.function.arguments), c.id
+                        )
+                    )
+                except json.JSONDecodeError:
+                    items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
+
+        return ModelResponse(items, timestamp=timestamp)
+
+
 @dataclass(init=False)
 class OllamaModel(Model):
     """A model that implements Ollama using the OpenAI API.
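# --- Example (not part of the diff): why NestedJSONDecoder is needed. ---
# Some Ollama models return tool-call arguments whose nested objects are
# double-encoded, i.e. serialized as JSON strings inside the outer JSON
# (which is where the "extra escape characters" come from). The payload
# below is hypothetical; it assumes the merged module is importable as
# pydantic_ai.models.ollama.
import json

from pydantic_ai.models.ollama import NestedJSONDecoder

raw = '{"location": "{\\"city\\": \\"Paris\\", \\"country\\": \\"France\\"}", "units": "metric"}'

json.loads(raw)['location']
#> '{"city": "Paris", "country": "France"}'  (still a string)

NestedJSONDecoder().decode(raw)
#> {'location': {'city': 'Paris', 'country': 'France'}, 'units': 'metric'}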
@@ -102,6 +157,17 @@ def __init__(
         oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_)
         self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
 
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
+        return {
+            'type': 'function',
+            'function': {
+                'name': f.name,
+                'description': f.description,
+                'parameters': f.parameters_json_schema,
+            },
+        }
+
     async def agent_model(
         self,
         *,
@@ -109,10 +175,15 @@ async def agent_model(
         allow_text_result: bool,
         result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        return await self.openai_model.agent_model(
-            function_tools=function_tools,
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+
+        return OllamaAgentModel(
+            client=self.openai_model.client,
+            model_name=self.openai_model.model_name,
             allow_text_result=allow_text_result,
-            result_tools=result_tools,
+            tools=tools,
         )
 
     def name(self) -> str:
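# --- Example (not part of the diff): what _map_tool_definition produces. ---
# A hypothetical tool definition, assuming ToolDefinition is constructible
# from these three fields; the result is the OpenAI-compatible tool spec
# that Ollama's OpenAI-style endpoint accepts.
from pydantic_ai.models.ollama import OllamaModel
from pydantic_ai.tools import ToolDefinition

weather_tool = ToolDefinition(
    name='get_weather',
    description='Get the current weather for a city.',
    parameters_json_schema={
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
)

OllamaModel._map_tool_definition(weather_tool)
#> {'type': 'function', 'function': {'name': 'get_weather', 'description': 'Get the current weather for a city.', 'parameters': {...}}}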