Skip to content

Add provider for Anthropic's Vertexai Client #1392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests

try:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncAnthropicVertex, AsyncStream
from anthropic.types import (
Base64PDFSourceParam,
ContentBlock,
Expand Down Expand Up @@ -108,7 +108,7 @@ class AnthropicModel(Model):
We anticipate adding support for streaming responses in a near-term future release.
"""

client: AsyncAnthropic = field(repr=False)
client: AsyncAnthropic | AsyncAnthropicVertex = field(repr=False)

_model_name: AnthropicModelName = field(repr=False)
_system: str = field(default='anthropic', repr=False)
Expand All @@ -117,15 +117,21 @@ def __init__(
self,
model_name: AnthropicModelName,
*,
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
# breaking this in multiple lines breaks pycharm type recognition. However, I was unable to stop ruff from
# doing it - # fmt: skip etc didn't work :(
provider: Literal['anthropic', 'anthropic-vertex']
| Provider[AsyncAnthropicVertex]
| Provider[AsyncAnthropic] = # fmt: skip
'anthropic',
Comment on lines +120 to +125
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to do # fmt: off and # fmt: on after.

):
"""Initialize an Anthropic model.

Args:
model_name: The name of the Anthropic model to use. List of model names available
[here](https://docs.anthropic.com/en/docs/about-claude/models).
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic',
'anthropic-vertex', or an instance of Provider[AsyncAnthropic] or Provider[AsyncAnthropicVertex].
Defaults to 'anthropic'.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you should also add an entry to the KnownModelName literal at models/__init__.py

"""
self._model_name = model_name

Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def infer_provider(provider: str) -> Provider[Any]:
from .anthropic import AnthropicProvider

return AnthropicProvider()
elif provider == 'anthropic-vertex':
from .anthropic_vertex import AnthropicVertexProvider

return AnthropicVertexProvider()
Comment on lines +76 to +79
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tricky. If you want google-vertex:<anthropic-model>, the agent should be able to infer this provider instead of the other... Maybe the provider should offer different clients depending on the model?

Right now, the GoogleVertexProvider is a generic on the client, which currently is httpx.AsyncClient. What the provider can offer a client via method?

class GoogleVertexProvider(Provider):
    def get_client(self, tp: type[T]) -> T:
        if isinstance(tp, httpx.AsyncClient):
            return self.httpx_client
        elif isinstance(tp, AsyncAnthropicVertex):
            return self.anthropic_client
        else:
            raise ValueError('not supported')

elif provider == 'mistral':
from .mistral import MistralProvider

Expand Down
43 changes: 43 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/anthropic_vertex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations as _annotations

from pydantic_ai.providers import Provider

try:
from anthropic import AsyncAnthropicVertex
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `anthropic` package to use the Anthropic provider, '
'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
) from _import_error


class AnthropicVertexProvider(Provider[AsyncAnthropicVertex]):
"""Provider for Anthropic API."""

@property
def name(self) -> str:
return 'anthropic-vertex'

@property
def base_url(self) -> str:
return str(self._client.base_url)

@property
def client(self) -> AsyncAnthropicVertex:
return self._client

def __init__(
self,
*,
anthropic_client: AsyncAnthropicVertex | None = None,
) -> None:
"""Create a new Anthropic provider.

Args:
anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python)
client to use. If provided, the `api_key` and `http_client` arguments will be ignored.
"""
if anthropic_client:
self._client = anthropic_client
else:
self._client = AsyncAnthropicVertex()
1 change: 1 addition & 0 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ openai = ["openai>=1.67.0"]
cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.49.0"]
antrhopic-vertex = ["anthropic[vertex]>=0.49.0"]
groq = ["groq>=0.15.0"]
mistral = ["mistralai>=1.2.5"]
bedrock = ["boto3>=1.34.116"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ requires-python = ">=3.9"

[tool.hatch.metadata.hooks.uv-dynamic-versioning]
dependencies = [
"pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals]=={{ version }}",
"pydantic-ai-slim[openai,vertexai,groq,anthropic,anthropic-vertex,mistral,cohere,bedrock,cli,mcp,evals]=={{ version }}",
]

[tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]
Expand Down
21 changes: 19 additions & 2 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.providers.anthropic_vertex import AnthropicVertexProvider
from pydantic_ai.result import Usage
from pydantic_ai.settings import ModelSettings

from ..conftest import IsDatetime, IsNow, IsStr, TestEnv, raise_if_exception, try_import
from .mock_async_stream import MockAsyncStream

with try_import() as imports_successful:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncAnthropicVertex
from anthropic.types import (
ContentBlock,
InputJSONDelta,
Expand Down Expand Up @@ -70,7 +71,7 @@

def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
assert m.client.api_key == 'foobar'
# assert m.client.api_key == 'foobar'
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'
assert m.base_url == 'https://api.anthropic.com'
Expand Down Expand Up @@ -682,3 +683,19 @@ def test_init_with_provider_string(env: TestEnv):
model = AnthropicModel('claude-3-opus-latest', provider='anthropic')
assert model.model_name == 'claude-3-opus-latest'
assert model.client is not None


def test_init_with_vertex_provider():
client = AsyncAnthropicVertex(project_id='foo', region='us-east1')
provider = AnthropicVertexProvider(anthropic_client=client)
model = AnthropicModel('claude-3-opus-latest', provider=provider)
assert model.model_name == 'claude-3-opus-latest'
assert model.client == provider.client


def test_init_with_vertex_provider_string(env: TestEnv):
env.set('CLOUD_ML_REGION', 'us-east1')
model = AnthropicModel('claude-3-opus-latest', provider='anthropic-vertex')
assert model.model_name == 'claude-3-opus-latest'
assert model.client is not None
assert isinstance(model.client, AsyncAnthropicVertex)
35 changes: 35 additions & 0 deletions tests/providers/test_anthropic_vertex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations as _annotations

import pytest

from ..conftest import try_import

with try_import() as imports_successful:
from anthropic import AsyncAnthropicVertex

from pydantic_ai.providers.anthropic_vertex import AnthropicVertexProvider


pytestmark = pytest.mark.skipif(not imports_successful(), reason='need to install anthropic-vertex')


def test_anthropic_provider_with_project_and_region():
mock_region = 'us-east5'
client = AsyncAnthropicVertex(project_id='test-project', region=mock_region)
provider = AnthropicVertexProvider(anthropic_client=client)
assert provider.name == 'anthropic-vertex'
assert provider.base_url == f'https://{mock_region}-aiplatform.googleapis.com/v1/'
assert isinstance(provider.client, AsyncAnthropicVertex)
assert provider.client.region == mock_region


def test_anthropic_provider_with_empty_client_and_valid_env(monkeypatch: pytest.MonkeyPatch):
mock_region = 'europe-west3'
monkeypatch.setenv('CLOUD_ML_REGION', mock_region)

client = AsyncAnthropicVertex()
provider = AnthropicVertexProvider(anthropic_client=client)
assert provider.name == 'anthropic-vertex'
assert provider.client.region == mock_region
assert provider.base_url == f'https://{mock_region}-aiplatform.googleapis.com/v1/'
assert isinstance(provider.client, AsyncAnthropicVertex)
13 changes: 11 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading