diff --git a/README.md b/README.md index e0a4dfb..5e52659 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ A zero-configuration tool for automatically exposing FastAPI endpoints as Model - **Preserving schemas** of your request models and response models - **Preserve documentation** of all your endpoints, just as it is in Swagger - **Extend** - Add custom MCP tools alongside the auto-generated ones +- **Authentication** - Secure your MCP server with various authentication methods ## Installation @@ -49,6 +50,71 @@ add_mcp_server( That's it! Your auto-generated MCP server is now available at `https://app.base.url/mcp`. +## Authentication + +FastAPI-MCP supports various authentication methods to secure your MCP server: + +### Simple Bearer Token Authentication + +```python +from fastapi import FastAPI +from fastapi_mcp import add_mcp_server, AuthConfig + +app = FastAPI() + +# Configure authentication with a bearer token +auth_config = AuthConfig( + enabled=True, + bearer_token="your-secret-token" # This should be a secure token +) + +# Add MCP server with authentication +mcp_server = add_mcp_server( + app, + mount_path="/mcp", + name="Authenticated MCP API", + auth_config=auth_config +) +``` + +### API Key Authentication + +```python +from fastapi_mcp import AuthConfig + +# Configure authentication with an API key in header +auth_config = AuthConfig( + enabled=True, + api_key="your-api-key", + api_key_name="X-API-Key", + api_key_in="header" # Can be "header" or "query" +) +``` + +### Custom Authentication + +For more complex authentication scenarios, you can use a custom authentication function: + +```python +from fastapi import Request +from fastapi_mcp import AuthConfig + +# Define your custom authentication function +async def my_auth_function(request: Request) -> bool: + # Your authentication logic here + # Return True if authenticated, False otherwise + token = request.headers.get("Authorization", "").replace("Bearer ", "") + return token == "valid-token" + +# Configure authentication with a custom function +auth_config = AuthConfig( + enabled=True, + custom_auth_func=my_auth_function +) +``` + +See the [examples/authenticated_server.py](examples/authenticated_server.py) and [examples/simple_bearer_auth.py](examples/simple_bearer_auth.py) for complete working examples. + ## Advanced Usage FastAPI-MCP provides several ways to customize and control how your MCP server is created and configured. Here are some advanced usage patterns: @@ -160,4 +226,4 @@ MIT License. Copyright (c) 2024 Tadata Inc. ## About -Developed and maintained by [Tadata Inc.](https://github.com/tadata-org) +Developed and maintained by [Tadata Inc.](https://github.com/tadata-org) \ No newline at end of file diff --git a/examples/authenticated_server.py b/examples/authenticated_server.py new file mode 100644 index 0000000..1ff43be --- /dev/null +++ b/examples/authenticated_server.py @@ -0,0 +1,231 @@ +""" +Example of a FastAPI-MCP server with authentication. + +This example demonstrates how to add authentication to an MCP server. +""" + +from fastapi import FastAPI, Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from pydantic import BaseModel +from typing import List, Optional, Dict +from datetime import datetime, timedelta +import jwt +from jwt.exceptions import PyJWTError + +from fastapi_mcp import add_mcp_server, AuthConfig + +# Create FastAPI app +app = FastAPI( + title="Authenticated MCP Example", + description="An example of an MCP server with authentication", +) + +# Secret key for JWT token +SECRET_KEY = "a_very_secret_key_that_should_be_changed_in_production" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +# Fake users database +fake_users_db = { + "johndoe": { + "username": "johndoe", + "full_name": "John Doe", + "email": "john@example.com", + "hashed_password": "fakehashedsecret", + "disabled": False, + }, + "alice": { + "username": "alice", + "full_name": "Alice Wonderland", + "email": "alice@example.com", + "hashed_password": "fakehashedsecret2", + "disabled": False, + }, +} + +# OAuth2 scheme +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +# Models +class User(BaseModel): + username: str + email: Optional[str] = None + full_name: Optional[str] = None + disabled: Optional[bool] = None + + +class UserInDB(User): + hashed_password: str + + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + username: Optional[str] = None + + +def verify_password(plain_password, hashed_password): + """Verify password.""" + # This is a fake hash verification - in a real app, use proper hashing + return plain_password == "secret" and hashed_password == "fakehashedsecret" + + +def get_user(db, username: str): + """Get user from the database.""" + if username in db: + user_dict = db[username] + return UserInDB(**user_dict) + + +def authenticate_user(fake_db, username: str, password: str): + """Authenticate a user.""" + user = get_user(fake_db, username) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + """Create a JWT token.""" + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +async def get_current_user(token: str = Depends(oauth2_scheme)): + """Get the current user from JWT token.""" + credentials_exception = HTTPException( + status_code=401, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except PyJWTError: + raise credentials_exception + user = get_user(fake_users_db, username=token_data.username) + if user is None: + raise credentials_exception + return user + + +async def get_current_active_user(current_user: User = Depends(get_current_user)): + """Get the current active user.""" + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +@app.post("/token", response_model=Token) +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + """Login to get an access token.""" + user = authenticate_user(fake_users_db, form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=401, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) + return {"access_token": access_token, "token_type": "bearer"} + + +@app.get("/users/me/", response_model=User) +async def read_users_me(current_user: User = Depends(get_current_active_user)): + """Get current user information.""" + return current_user + + +# Sample items for a simple API +items_db: Dict[str, dict] = { + "1": {"id": "1", "name": "Foo", "description": "This is foo"}, + "2": {"id": "2", "name": "Bar", "description": "This is bar"}, + "3": {"id": "3", "name": "Baz", "description": "This is baz"}, +} + + +@app.get("/items/", response_model=List[dict]) +async def read_items(current_user: User = Depends(get_current_active_user)): + """Get all items, requires authentication.""" + return list(items_db.values()) + + +@app.get("/items/{item_id}", response_model=dict) +async def read_item(item_id: str, current_user: User = Depends(get_current_active_user)): + """Get a specific item, requires authentication.""" + if item_id not in items_db: + raise HTTPException(status_code=404, detail="Item not found") + return items_db[item_id] + + +# Function to verify JWT tokens for MCP authentication +async def authenticate_mcp_request(request): + """Custom authentication function for MCP server.""" + try: + # Get token from Authorization header + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("Bearer "): + return False + + token = auth_header.replace("Bearer ", "") + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username = payload.get("sub") + + # Verify user exists in database + if not username or username not in fake_users_db: + return False + + # Check if user is disabled + if fake_users_db[username].get("disabled", False): + return False + + return True + except Exception: + return False + + +# Add MCP server with authentication +auth_config = AuthConfig( + enabled=True, + # Define a custom auth function that validates JWT tokens + custom_auth_func=authenticate_mcp_request +) + +mcp_server = add_mcp_server( + app, + mount_path="/mcp", + name="Authenticated MCP API", + description="MCP server with JWT authentication", + base_url="http://localhost:8000", + auth_config=auth_config +) + + +# Add a custom MCP tool +@mcp_server.tool() +async def get_item_count() -> int: + """Get the total number of items in the database.""" + return len(items_db) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="127.0.0.1", port=8000) \ No newline at end of file diff --git a/examples/simple_bearer_auth.py b/examples/simple_bearer_auth.py new file mode 100644 index 0000000..2b89fd7 --- /dev/null +++ b/examples/simple_bearer_auth.py @@ -0,0 +1,93 @@ +""" +Example of a FastAPI-MCP server with simple bearer token authentication. + +This example demonstrates how to add simple bearer token authentication to an MCP server. +""" + +from fastapi import FastAPI, Depends, HTTPException, Header +from typing import List, Dict, Optional + +from fastapi_mcp import add_mcp_server, AuthConfig + +# Create FastAPI app +app = FastAPI( + title="Simple Bearer Auth MCP Example", + description="An example of an MCP server with simple bearer token authentication", +) + +# Define a fixed API token for simplicity +# In a real application, this should be stored securely +API_TOKEN = "mcp-test-token-123456" + + +# Helper function to verify the token in regular FastAPI endpoints +async def verify_token(authorization: Optional[str] = Header(None)): + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = authorization.replace("Bearer ", "") + if token != API_TOKEN: + raise HTTPException( + status_code=401, + detail="Invalid token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return token + + +# Sample items database +items_db: Dict[str, dict] = { + "1": {"id": "1", "name": "Foo", "description": "This is foo"}, + "2": {"id": "2", "name": "Bar", "description": "This is bar"}, + "3": {"id": "3", "name": "Baz", "description": "This is baz"}, +} + + +@app.get("/items/", response_model=List[dict]) +async def read_items(token: str = Depends(verify_token)): + """Get all items, requires authentication.""" + return list(items_db.values()) + + +@app.get("/items/{item_id}", response_model=dict) +async def read_item(item_id: str, token: str = Depends(verify_token)): + """Get a specific item, requires authentication.""" + if item_id not in items_db: + raise HTTPException(status_code=404, detail="Item not found") + return items_db[item_id] + + +# Add MCP server with authentication using a simple bearer token +auth_config = AuthConfig( + enabled=True, + bearer_token=API_TOKEN # Directly use the API token for bearer auth +) + +mcp_server = add_mcp_server( + app, + mount_path="/mcp", + name="Bearer Auth MCP API", + description="MCP server with simple bearer token authentication", + base_url="http://localhost:8000", + auth_config=auth_config +) + + +# Add a custom MCP tool +@mcp_server.tool() +async def get_item_count() -> int: + """Get the total number of items in the database.""" + return len(items_db) + + +if __name__ == "__main__": + import uvicorn + print(f"Starting server with bearer token authentication. Use token: {API_TOKEN}") + print("Example: curl -H 'Authorization: Bearer mcp-test-token-123456' http://localhost:8000/items/") + print("For MCP: Use the same Authorization header when connecting to /mcp") + uvicorn.run(app, host="127.0.0.1", port=8000) \ No newline at end of file diff --git a/fastapi_mcp/__init__.py b/fastapi_mcp/__init__.py index 7af9dd7..8d5cb59 100644 --- a/fastapi_mcp/__init__.py +++ b/fastapi_mcp/__init__.py @@ -8,10 +8,14 @@ from .server import add_mcp_server, create_mcp_server, mount_mcp_server from .http_tools import create_mcp_tools_from_openapi +from .auth import AuthConfig, MCPAuthenticator, create_auth_dependency __all__ = [ "add_mcp_server", "create_mcp_server", "mount_mcp_server", "create_mcp_tools_from_openapi", -] + "AuthConfig", + "MCPAuthenticator", + "create_auth_dependency", +] \ No newline at end of file diff --git a/fastapi_mcp/auth.py b/fastapi_mcp/auth.py new file mode 100644 index 0000000..f775d2c --- /dev/null +++ b/fastapi_mcp/auth.py @@ -0,0 +1,163 @@ +""" +Authentication module for FastAPI-MCP. + +This module provides functionality for adding authentication to MCP servers in FastAPI applications. +""" + +from typing import Optional, Callable, Dict, Any, Union +from fastapi import Request, HTTPException, Depends, Security +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.security.api_key import APIKeyHeader, APIKeyQuery +from pydantic import BaseModel + + +class AuthConfig(BaseModel): + """Configuration for MCP authentication.""" + enabled: bool = True + bearer_token: Optional[str] = None + api_key: Optional[str] = None + api_key_name: str = "X-API-Key" + api_key_in: str = "header" # "header" or "query" + custom_auth_func: Optional[Callable[[Request], bool]] = None + + +class MCPAuthenticator: + """Class to handle authentication for MCP connections.""" + + def __init__(self, config: AuthConfig): + self.config = config + + # Set up the security schemes based on configuration + if self.config.api_key and self.config.api_key_in == "header": + self.api_key_header = APIKeyHeader(name=self.config.api_key_name, auto_error=False) + elif self.config.api_key and self.config.api_key_in == "query": + self.api_key_query = APIKeyQuery(name=self.config.api_key_name, auto_error=False) + + if self.config.bearer_token: + self.http_bearer = HTTPBearer(auto_error=False) + + async def authenticate_request(self, request: Request) -> bool: + """ + Authenticate an incoming request against configured auth methods. + Returns True if authentication succeeds, False otherwise. + """ + if not self.config.enabled: + return True + + # Check bearer token if configured + if self.config.bearer_token: + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.replace("Bearer ", "") + if token == self.config.bearer_token: + return True + + # Check API key if configured + if self.config.api_key: + if self.config.api_key_in == "header": + api_key = request.headers.get(self.config.api_key_name) + if api_key and api_key == self.config.api_key: + return True + elif self.config.api_key_in == "query": + api_key = request.query_params.get(self.config.api_key_name) + if api_key and api_key == self.config.api_key: + return True + + # Use custom auth function if provided + if self.config.custom_auth_func: + try: + return await self.config.custom_auth_func(request) + except Exception: + return False + + # If we get here, authentication failed + return False + + async def authenticate_or_raise(self, request: Request) -> None: + """Authenticate request or raise HTTP 401 exception.""" + if not await self.authenticate_request(request): + raise HTTPException( + status_code=401, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Dependencies for use with FastAPI routes + async def verify_bearer_token( + self, credentials: HTTPAuthorizationCredentials = Security(HTTPBearer()) + ) -> str: + """Verify bearer token for FastAPI dependency injection.""" + if not self.config.enabled: + return "authenticated" + + if not self.config.bearer_token: + raise HTTPException( + status_code=500, + detail="Bearer token authentication not configured", + ) + + if credentials.credentials != self.config.bearer_token: + raise HTTPException( + status_code=401, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return credentials.credentials + + async def verify_api_key_header( + self, api_key: str = Security(APIKeyHeader(name="X-API-Key")) + ) -> str: + """Verify API key in header for FastAPI dependency injection.""" + if not self.config.enabled: + return "authenticated" + + if not self.config.api_key: + raise HTTPException( + status_code=500, + detail="API key authentication not configured" + ) + + if api_key != self.config.api_key: + raise HTTPException( + status_code=401, + detail="Invalid API Key" + ) + return api_key + + async def verify_api_key_query( + self, api_key: str = Security(APIKeyQuery(name="api_key")) + ) -> str: + """Verify API key in query for FastAPI dependency injection.""" + if not self.config.enabled: + return "authenticated" + + if not self.config.api_key: + raise HTTPException( + status_code=500, + detail="API key authentication not configured" + ) + + if api_key != self.config.api_key: + raise HTTPException( + status_code=401, + detail="Invalid API Key" + ) + return api_key + +# Function to create an auth dependency from various authentication options +def create_auth_dependency( + auth_config: Union[AuthConfig, Dict[str, Any]] +) -> Callable: + """Create a FastAPI dependency for authentication based on the provided config.""" + + # Convert dict to AuthConfig if needed + if isinstance(auth_config, dict): + auth_config = AuthConfig(**auth_config) + + authenticator = MCPAuthenticator(auth_config) + + async def auth_dependency(request: Request) -> None: + """FastAPI dependency for authentication.""" + await authenticator.authenticate_or_raise(request) + + return auth_dependency \ No newline at end of file diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index db65adb..d7b8c29 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -4,14 +4,15 @@ This module provides functionality for creating and mounting MCP servers to FastAPI applications. """ -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Union -from fastapi import FastAPI +from fastapi import FastAPI, Request, HTTPException from mcp.server.fastmcp import FastMCP from mcp.server.sse import SseServerTransport from fastapi import Request from .http_tools import create_mcp_tools_from_openapi +from .auth import AuthConfig, MCPAuthenticator, create_auth_dependency def create_mcp_server( @@ -56,6 +57,7 @@ def mount_mcp_server( base_url: Optional[str] = None, describe_all_responses: bool = False, describe_full_response_schema: bool = False, + auth_config: Optional[Union[AuthConfig, Dict[str, Any]]] = None, ) -> None: """ Mount an MCP server to a FastAPI app. @@ -66,8 +68,11 @@ def mount_mcp_server( mount_path: Path where the MCP server will be mounted serve_tools: Whether to serve tools from the FastAPI app base_url: Base URL for API requests - describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error. - describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens. + describe_all_responses: Whether to include all possible response schemas in tool descriptions. + Recommended to keep False, as the LLM will probably derive if there is an error. + describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. + Recommended to keep False, as examples are more LLM friendly, and save tokens. + auth_config: Optional authentication configuration """ # Normalize mount path if not mount_path.startswith("/"): @@ -78,8 +83,24 @@ def mount_mcp_server( # Create SSE transport for MCP messages sse_transport = SseServerTransport(f"{mount_path}/messages/") + # Set up authentication if configured + authenticator = None + if auth_config: + if isinstance(auth_config, dict): + auth_config = AuthConfig(**auth_config) + authenticator = MCPAuthenticator(auth_config) + # Define MCP connection handler async def handle_mcp_connection(request: Request): + # Check authentication if enabled + if authenticator and authenticator.config.enabled: + if not await authenticator.authenticate_request(request): + raise HTTPException( + status_code=401, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: await mcp_server._mcp_server.run( streams[0], @@ -89,7 +110,18 @@ async def handle_mcp_connection(request: Request): # Mount the MCP connection handler app.get(mount_path)(handle_mcp_connection) - app.mount(f"{mount_path}/messages/", app=sse_transport.handle_post_message) + + # Handle SSE message posting with authentication if configured + if authenticator and authenticator.config.enabled: + # Create a wrapped message handler that checks authentication + async def authenticated_message_handler(request: Request): + await authenticator.authenticate_or_raise(request) + return await sse_transport.handle_post_message(request) + + app.post(f"{mount_path}/messages/")(authenticated_message_handler) + else: + # Use the default message handler without authentication + app.mount(f"{mount_path}/messages/", app=sse_transport.handle_post_message) # Serve tools from the FastAPI app if requested if serve_tools: @@ -112,6 +144,7 @@ def add_mcp_server( base_url: Optional[str] = None, describe_all_responses: bool = False, describe_full_response_schema: bool = False, + auth_config: Optional[Union[AuthConfig, Dict[str, Any]]] = None, ) -> FastMCP: """ Add an MCP server to a FastAPI app. @@ -124,8 +157,11 @@ def add_mcp_server( capabilities: Optional capabilities for the MCP server serve_tools: Whether to serve tools from the FastAPI app base_url: Base URL for API requests (defaults to http://localhost:$PORT) - describe_all_responses: Whether to include all possible response schemas in tool descriptions. Recommended to keep False, as the LLM will probably derive if there is an error. - describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. Recommended to keep False, as examples are more LLM friendly, and save tokens. + describe_all_responses: Whether to include all possible response schemas in tool descriptions. + Recommended to keep False, as the LLM will probably derive if there is an error. + describe_full_response_schema: Whether to include full json schema for responses in tool descriptions. + Recommended to keep False, as examples are more LLM friendly, and save tokens. + auth_config: Optional authentication configuration for the MCP server Returns: The FastMCP instance that was created and mounted @@ -142,6 +178,7 @@ def add_mcp_server( base_url, describe_all_responses=describe_all_responses, describe_full_response_schema=describe_full_response_schema, + auth_config=auth_config, ) - return mcp_server + return mcp_server \ No newline at end of file