Skip to content

Feature/response models #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions fastapi_mcp/http_tools.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,9 @@
logger = logging.getLogger("fastapi_mcp")


def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str, Any]) -> Dict[str, Any]:
def resolve_schema_references(
schema: Dict[str, Any], openapi_schema: Dict[str, Any], top_schema=None
) -> Dict[str, Any]:
"""
Resolve schema references in OpenAPI schemas.

@@ -31,6 +33,9 @@ def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str,
# Make a copy to avoid modifying the input schema
schema = schema.copy()

# Create a a definnition prefix for the schema
def_prefix = "#/$defs/"

# Handle $ref directly in the schema
if "$ref" in schema:
ref_path = schema["$ref"]
@@ -41,18 +46,42 @@ def resolve_schema_references(schema: Dict[str, Any], openapi_schema: Dict[str,
if model_name in openapi_schema["components"]["schemas"]:
# Replace with the resolved schema
ref_schema = openapi_schema["components"]["schemas"][model_name].copy()
# Remove the $ref key and merge with the original schema
schema.pop("$ref")
schema.update(ref_schema)

if top_schema is not None:
# Create the $defs key if it doesn't exist
if "$defs" not in top_schema:
top_schema["$defs"] = {}

ref_schema = resolve_schema_references(ref_schema, openapi_schema, top_schema=top_schema)

# Create the definition reference
top_schema["$defs"][model_name] = ref_schema

# Update the schema with the definition reference
schema["$ref"] = def_prefix + model_name
else:
# Update the schema with the definition reference
schema.pop("$ref")
schema.update(ref_schema)
top_schema = schema

# Handle anyOf, oneOf, allOf
for key in ["anyOf", "oneOf", "allOf"]:
if key in schema:
for index, item in enumerate(schema[key]):
item = resolve_schema_references(item, openapi_schema, top_schema=top_schema)
schema[key][index] = item

# Handle array items
if "type" in schema and schema["type"] == "array" and "items" in schema:
schema["items"] = resolve_schema_references(schema["items"], openapi_schema)
schema["items"] = resolve_schema_references(schema["items"], openapi_schema, top_schema=top_schema)

# Handle object properties
if "properties" in schema:
for prop_name, prop_schema in schema["properties"].items():
schema["properties"][prop_name] = resolve_schema_references(prop_schema, openapi_schema)
schema["properties"][prop_name] = resolve_schema_references(
prop_schema, openapi_schema, top_schema=top_schema
)

return schema

@@ -72,9 +101,6 @@ def clean_schema_for_display(schema: Dict[str, Any]) -> Dict[str, Any]:

# Remove common internal fields that are not helpful for LLMs
fields_to_remove = [
"allOf",
"anyOf",
"oneOf",
"nullable",
"discriminator",
"readOnly",
@@ -481,7 +507,11 @@ async def http_tool_function(kwargs: Dict[str, Any] = Field(default_factory=dict
return response.text

# Create a proper input schema for the tool
input_schema = {"type": "object", "properties": properties, "title": f"{operation_id}Arguments"}
input_schema = {
"type": "object",
"properties": properties,
"title": f"{operation_id}Arguments",
}

if required_props:
input_schema["required"] = required_props
@@ -561,7 +591,15 @@ def generate_example_from_schema(schema: Dict[str, Any], model_name: Optional[st
}
elif model_name == "HTTPValidationError":
# Create a realistic validation error example
return {"detail": [{"loc": ["body", "name"], "msg": "field required", "type": "value_error.missing"}]}
return {
"detail": [
{
"loc": ["body", "name"],
"msg": "field required",
"type": "value_error.missing",
}
]
}

# Handle different types
schema_type = schema.get("type")
16 changes: 14 additions & 2 deletions tests/test_http_tools.py
Original file line number Diff line number Diff line change
@@ -86,7 +86,13 @@ def test_resolve_schema_references():
openapi_schema = {
"components": {
"schemas": {
"Item": {"type": "object", "properties": {"id": {"type": "integer"}, "name": {"type": "string"}}}
"Item": {
"type": "object",
"properties": {
"id": {"type": "integer"},
"name": {"type": "string"},
},
}
}
}
}
@@ -141,7 +147,13 @@ def test_create_mcp_tools_from_complex_app(complex_app):
assert len(api_tools) == 5, f"Expected 5 API tools, got {len(api_tools)}"

# Check for all expected tools with the correct name pattern
tool_operations = ["list_items", "read_item", "create_item", "update_item", "delete_item"]
tool_operations = [
"list_items",
"read_item",
"create_item",
"update_item",
"delete_item",
]
for operation in tool_operations:
matching_tools = [t for t in tools if operation in t.name]
assert len(matching_tools) > 0, f"No tool found for operation '{operation}'"
74 changes: 69 additions & 5 deletions tests/test_tool_generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

import pytest
from fastapi import FastAPI
from pydantic import BaseModel
@@ -14,6 +16,29 @@ class Item(BaseModel):
tags: List[str] = []


class Task(BaseModel):
id: int
title: str
description: Optional[str] = None
completed: bool = False
required_resources: List[Item] = []


def remove_default_values(schema: dict) -> dict:
if "default" in schema:
schema.pop("default")

for value in schema.values():
if isinstance(value, dict):
remove_default_values(value)

return schema


def normalize_json_schema(schema: dict) -> str:
return json.dumps(remove_default_values(schema), sort_keys=True)


@pytest.fixture
def sample_app():
"""Create a sample FastAPI app for testing."""
@@ -50,6 +75,14 @@ async def create_item(item: Item):
"""
return item

@app.get("/tasks/", response_model=List[Task], tags=["tasks"])
async def list_tasks(
skip: int = 0,
limit: int = 10,
):
"""List all tasks with pagination options."""
return []

return app


@@ -96,7 +129,10 @@ def test_tool_generation_with_full_schema(sample_app):
"""Test that MCP tools include full response schema when requested."""
# Create MCP server with full schema for all operations
mcp_server = add_mcp_server(
sample_app, serve_tools=True, base_url="http://localhost:8000", describe_full_response_schema=True
sample_app,
serve_tools=True,
base_url="http://localhost:8000",
describe_full_response_schema=True,
)

# Extract tools for inspection
@@ -109,16 +145,44 @@ def test_tool_generation_with_full_schema(sample_app):
continue

description = tool.description
# Check that the tool includes information about the Item schema
assert "Item" in description, f"Item schema should be included in the description for {tool.name}"
assert "price" in description, f"Item properties should be included in the description for {tool.name}"

# Check that the tool includes information about the Item or Task schema
if tool.name == "list_tasks_tasks__get":
model = Task
elif "Item" in description:
model = Item
elif "Task" not in description:
raise ValueError(f"Item or Task schema should be included in the description for {tool.name}")

assert "price" in description or "required_resources" in description, (
f"Item or Task properties should be included in the description for {tool.name}"
)

# Get the output schema from the description
lines = description.split("\n")
for index, line in enumerate(lines):
if "Output Schema" in line:
index += 2
break

# Normalize the output schema
output_schema_str = normalize_json_schema(json.loads("\n".join(lines[index:-1])))

# Generate and normalize the model schema
model_schema_str = normalize_json_schema(model.model_json_schema())

# Check that the output schema matches the model schema
assert output_schema_str == model_schema_str, f"Output schema does not match model schema for {tool.name}"


def test_tool_generation_with_all_responses(sample_app):
"""Test that MCP tools include all possible responses when requested."""
# Create MCP server with all response status codes
mcp_server = add_mcp_server(
sample_app, serve_tools=True, base_url="http://localhost:8000", describe_all_responses=True
sample_app,
serve_tools=True,
base_url="http://localhost:8000",
describe_all_responses=True,
)

# Extract tools for inspection