diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 3aeb8c8f1..5045081c3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -34,7 +34,7 @@ from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests try: - from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream + from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncAnthropicVertex, AsyncStream from anthropic.types import ( Base64PDFSourceParam, ContentBlock, @@ -108,7 +108,7 @@ class AnthropicModel(Model): We anticipate adding support for streaming responses in a near-term future release. """ - client: AsyncAnthropic = field(repr=False) + client: AsyncAnthropic | AsyncAnthropicVertex = field(repr=False) _model_name: AnthropicModelName = field(repr=False) _system: str = field(default='anthropic', repr=False) @@ -117,15 +117,21 @@ def __init__( self, model_name: AnthropicModelName, *, - provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic', + # breaking this in multiple lines breaks pycharm type recognition. However, I was unable to stop ruff from + # doing it - # fmt: skip etc didn't work :( + provider: Literal['anthropic', 'anthropic-vertex'] + | Provider[AsyncAnthropicVertex] + | Provider[AsyncAnthropic] = # fmt: skip + 'anthropic', ): """Initialize an Anthropic model. Args: model_name: The name of the Anthropic model to use. List of model names available [here](https://docs.anthropic.com/en/docs/about-claude/models). - provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an - instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used. + provider: The provider to use for the Anthropic API. Can be either the string 'anthropic', + 'anthropic-vertex', or an instance of Provider[AsyncAnthropic] or Provider[AsyncAnthropicVertex]. + Defaults to 'anthropic'. """ self._model_name = model_name diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 8a8ec4cb2..a1390c27f 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -73,6 +73,10 @@ def infer_provider(provider: str) -> Provider[Any]: from .anthropic import AnthropicProvider return AnthropicProvider() + elif provider == 'anthropic-vertex': + from .anthropic_vertex import AnthropicVertexProvider + + return AnthropicVertexProvider() elif provider == 'mistral': from .mistral import MistralProvider diff --git a/pydantic_ai_slim/pydantic_ai/providers/anthropic_vertex.py b/pydantic_ai_slim/pydantic_ai/providers/anthropic_vertex.py new file mode 100644 index 000000000..e22f79bcd --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/anthropic_vertex.py @@ -0,0 +1,43 @@ +from __future__ import annotations as _annotations + +from pydantic_ai.providers import Provider + +try: + from anthropic import AsyncAnthropicVertex +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `anthropic` package to use the Anthropic provider, ' + 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`' + ) from _import_error + + +class AnthropicVertexProvider(Provider[AsyncAnthropicVertex]): + """Provider for Anthropic API.""" + + @property + def name(self) -> str: + return 'anthropic-vertex' + + @property + def base_url(self) -> str: + return str(self._client.base_url) + + @property + def client(self) -> AsyncAnthropicVertex: + return self._client + + def __init__( + self, + *, + anthropic_client: AsyncAnthropicVertex | None = None, + ) -> None: + """Create a new Anthropic provider. + + Args: + anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python) + client to use. If provided, the `api_key` and `http_client` arguments will be ignored. + """ + if anthropic_client: + self._client = anthropic_client + else: + self._client = AsyncAnthropicVertex() diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index f4f09beaa..838aca83b 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -60,6 +60,7 @@ openai = ["openai>=1.67.0"] cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.49.0"] +antrhopic-vertex = ["anthropic[vertex]>=0.49.0"] groq = ["groq>=0.15.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.34.116"] diff --git a/pyproject.toml b/pyproject.toml index 5fe43ecb6..5b784d5fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,groq,anthropic,anthropic-vertex,mistral,cohere,bedrock,cli,mcp,evals]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index f3c7b32f2..bff35ade3 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -26,6 +26,7 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.providers.anthropic_vertex import AnthropicVertexProvider from pydantic_ai.result import Usage from pydantic_ai.settings import ModelSettings @@ -33,7 +34,7 @@ from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: - from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic + from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncAnthropicVertex from anthropic.types import ( ContentBlock, InputJSONDelta, @@ -70,7 +71,7 @@ def test_init(): m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar')) - assert m.client.api_key == 'foobar' + # assert m.client.api_key == 'foobar' assert m.model_name == 'claude-3-5-haiku-latest' assert m.system == 'anthropic' assert m.base_url == 'https://api.anthropic.com' @@ -682,3 +683,19 @@ def test_init_with_provider_string(env: TestEnv): model = AnthropicModel('claude-3-opus-latest', provider='anthropic') assert model.model_name == 'claude-3-opus-latest' assert model.client is not None + + +def test_init_with_vertex_provider(): + client = AsyncAnthropicVertex(project_id='foo', region='us-east1') + provider = AnthropicVertexProvider(anthropic_client=client) + model = AnthropicModel('claude-3-opus-latest', provider=provider) + assert model.model_name == 'claude-3-opus-latest' + assert model.client == provider.client + + +def test_init_with_vertex_provider_string(env: TestEnv): + env.set('CLOUD_ML_REGION', 'us-east1') + model = AnthropicModel('claude-3-opus-latest', provider='anthropic-vertex') + assert model.model_name == 'claude-3-opus-latest' + assert model.client is not None + assert isinstance(model.client, AsyncAnthropicVertex) diff --git a/tests/providers/test_anthropic_vertex.py b/tests/providers/test_anthropic_vertex.py new file mode 100644 index 000000000..44bca1165 --- /dev/null +++ b/tests/providers/test_anthropic_vertex.py @@ -0,0 +1,35 @@ +from __future__ import annotations as _annotations + +import pytest + +from ..conftest import try_import + +with try_import() as imports_successful: + from anthropic import AsyncAnthropicVertex + + from pydantic_ai.providers.anthropic_vertex import AnthropicVertexProvider + + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='need to install anthropic-vertex') + + +def test_anthropic_provider_with_project_and_region(): + mock_region = 'us-east5' + client = AsyncAnthropicVertex(project_id='test-project', region=mock_region) + provider = AnthropicVertexProvider(anthropic_client=client) + assert provider.name == 'anthropic-vertex' + assert provider.base_url == f'https://{mock_region}-aiplatform.googleapis.com/v1/' + assert isinstance(provider.client, AsyncAnthropicVertex) + assert provider.client.region == mock_region + + +def test_anthropic_provider_with_empty_client_and_valid_env(monkeypatch: pytest.MonkeyPatch): + mock_region = 'europe-west3' + monkeypatch.setenv('CLOUD_ML_REGION', mock_region) + + client = AsyncAnthropicVertex() + provider = AnthropicVertexProvider(anthropic_client=client) + assert provider.name == 'anthropic-vertex' + assert provider.client.region == mock_region + assert provider.base_url == f'https://{mock_region}-aiplatform.googleapis.com/v1/' + assert isinstance(provider.client, AsyncAnthropicVertex) diff --git a/uv.lock b/uv.lock index 235046b46..8f60cdef5 100644 --- a/uv.lock +++ b/uv.lock @@ -199,6 +199,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/74/5d90ad14d55fbe3f9c474fdcb6e34b4bed99e3be8efac98734a5ddce88c1/anthropic-0.49.0-py3-none-any.whl", hash = "sha256:bbc17ad4e7094988d2fa86b87753ded8dce12498f4b85fe5810f208f454a8375", size = 243368 }, ] +[package.optional-dependencies] +vertex = [ + { name = "google-auth" }, +] + [[package]] name = "anyio" version = "4.8.0" @@ -2820,7 +2825,7 @@ lint = [ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["anthropic", "bedrock", "cli", "cohere", "evals", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["anthropic", "anthropic-vertex", "bedrock", "cli", "cohere", "evals", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["examples", "logfire"] @@ -2894,6 +2899,9 @@ dependencies = [ anthropic = [ { name = "anthropic" }, ] +antrhopic-vertex = [ + { name = "anthropic", extra = ["vertex"] }, +] bedrock = [ { name = "boto3" }, ] @@ -2954,6 +2962,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" }, + { name = "anthropic", extras = ["vertex"], marker = "extra == 'antrhopic-vertex'", specifier = ">=0.49.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.34.116" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.13.11" }, @@ -2978,7 +2987,7 @@ requires-dist = [ { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["anthropic", "antrhopic-vertex", "bedrock", "cli", "cohere", "duckduckgo", "evals", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [