Skip to content

[Feature] New API to discover, fetch, and call tools from server extensions #1521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
80 changes: 80 additions & 0 deletions jupyter_server/extension/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import importlib
from itertools import starmap

import jsonschema
from tornado.gen import multi
from traitlets import Any, Bool, Dict, HasTraits, Instance, List, Unicode, default, observe
from traitlets import validate as validate_trait
Expand All @@ -13,6 +14,37 @@
from .config import ExtensionConfigManager
from .utils import ExtensionMetadataError, ExtensionModuleNotFound, get_loader, get_metadata

# probably this should go in it's own file? Not sure where though
MCP_TOOL_SCHEMA = {
"type": "object",
"properties": {
"name": {"type": "string"},
"description": {"type": "string"},
"inputSchema": {
"type": "object",
"properties": {
"type": {"type": "string", "enum": ["object"]},
"properties": {"type": "object"},
"required": {"type": "array", "items": {"type": "string"}},
},
"required": ["type", "properties"],
},
"annotations": {
"type": "object",
"properties": {
"title": {"type": "string"},
"readOnlyHint": {"type": "boolean"},
"destructiveHint": {"type": "boolean"},
"idempotentHint": {"type": "boolean"},
"openWorldHint": {"type": "boolean"},
},
"additionalProperties": True,
},
},
"required": ["name", "inputSchema"],
"additionalProperties": False,
}


class ExtensionPoint(HasTraits):
"""A simple API for connecting to a Jupyter Server extension
Expand Down Expand Up @@ -97,6 +129,38 @@ def module(self):
"""The imported module (using importlib.import_module)"""
return self._module

@property
def tools(self):
"""Structured tools exposed by this extension point, if any.

Searches for a `jupyter_server_extension_tools` function on the extension module or app.
"""
loc = self.app or self.module
if not loc:
return {}

tools_func = getattr(loc, "jupyter_server_extension_tools", None)
if not callable(tools_func):
return {}

tools = {}
try:
result = tools_func()
# Support (tools_dict, schema) or just tools_dict
if isinstance(result, tuple) and len(result) == 2:
tools_dict, schema = result
else:
tools_dict = result
schema = MCP_TOOL_SCHEMA

for name, tool in tools_dict.items():
jsonschema.validate(instance=tool["metadata"], schema=schema)
tools[name] = tool
except Exception as e:
# not sure if this should fail quietly, raise an error, or log it?
print(f"[tool-discovery] Failed to load tools from {self.module_name}: {e}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using print, we should make this a self.log.error.

Before we do that, though, we'll need to make ExtensionPoint a LoggingConfigurable and set the parent trait to ExtensionPackage. I can work on this in a separate PR and we can rebase here once that's merged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1522. If/when that gets merged, we can rebase your PR to use the logger created there.

return tools

def _get_linker(self):
"""Get a linker."""
if self.app:
Expand Down Expand Up @@ -443,6 +507,22 @@ def load_all_extensions(self):
for name in self.sorted_extensions:
self.load_extension(name)

def get_tools(self) -> Dict[str, Any]:
"""Aggregate and return structured tools (with metadata) from all enabled extensions."""
all_tools = {}

for ext_name, ext_pkg in self.extensions.items():
if not ext_pkg.enabled:
continue

for point in ext_pkg.extension_points.values():
for name, tool in point.tools.items():
if name in all_tools:
raise ValueError(f"Duplicate tool name detected: '{name}'")
all_tools[name] = tool

return all_tools

async def start_all_extensions(self):
"""Start all enabled extensions."""
# Sort the extension names to enforce deterministic loading
Expand Down
5 changes: 5 additions & 0 deletions jupyter_server/serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2540,6 +2540,11 @@ def load_server_extensions(self) -> None:
"""
self.extension_manager.load_all_extensions()


def get_tools(self):
"""Return tools exposed by all extensions."""
return self.extension_manager.get_tools()

def init_mime_overrides(self) -> None:
# On some Windows machines, an application has registered incorrect
# mimetypes in the registry.
Expand Down
1 change: 1 addition & 0 deletions jupyter_server/services/contents/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ async def post(self, path=""):
self.finish()



# -----------------------------------------------------------------------------
# URL to handler mappings
# -----------------------------------------------------------------------------
Expand Down
14 changes: 14 additions & 0 deletions jupyter_server/services/tools/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from tornado import web
from jupyter_server.base.handlers import APIHandler

class ListToolInfoHandler(APIHandler):
@web.authenticated
async def get(self):
tools = self.serverapp.extension_manager.discover_tools()
self.finish({"discovered_tools": tools})



default_handlers = [
(r"/api/tools", ListToolInfoHandler),
]
22 changes: 22 additions & 0 deletions tests/extension/mockextensions/mockext_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""A mock extension exposing a structured tool."""

def jupyter_server_extension_tools():
return {
"mock_tool": {
"metadata": {
"name": "mock_tool",
"description": "A mock tool for testing.",
"inputSchema": {
"type": "object",
"properties": {
"input": {"type": "string"}
},
"required": ["input"]
}
},
"callable": lambda input: f"Echo: {input}"
}
}

def _load_jupyter_server_extension(serverapp):
serverapp.log.info("Loaded mock tool extension.")
21 changes: 21 additions & 0 deletions tests/extension/mockextensions/mockext_tool_dupes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""A mock extension that defines a duplicate tool name to test conflict handling."""

def jupyter_server_extension_tools():
return {
"mock_tool": { # <-- duplicate on purpose
"metadata": {
"name": "mock_tool",
"description": "Conflicting tool name.",
"inputSchema": {
"type": "object",
"properties": {
"input": {"type": "string"}
}
}
},
"callable": lambda input: f"Echo again: {input}"
}
}

def _load_jupyter_server_extension(serverapp):
serverapp.log.info("Loaded dupe tool extension.")
39 changes: 39 additions & 0 deletions tests/extension/mockextensions/mockext_tool_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""A mock extension that provides a custom validation schema."""

OPENAI_TOOL_SCHEMA = {
"type": "object",
"properties": {
"name": {"type": "string"},
"description": {"type": "string"},
"parameters": {
"type": "object",
"properties": {
"input": {"type": "string"}
},
"required": ["input"]
}
},
"required": ["name", "parameters"]
}

def jupyter_server_extension_tools():
tools = {
"openai_style_tool": {
"metadata": {
"name": "openai_style_tool",
"description": "Tool using OpenAI-style parameters",
"parameters": {
"type": "object",
"properties": {
"input": {"type": "string"}
},
"required": ["input"]
}
},
"callable": lambda input: f"Got {input}"
}
}
return (tools, OPENAI_TOOL_SCHEMA)

def _load_jupyter_server_extension(serverapp):
serverapp.log.info("Loaded mock custom-schema extension.")
25 changes: 25 additions & 0 deletions tests/extension/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,28 @@ def test_disable_no_import(jp_serverapp, has_app):
assert ext_pkg.extension_points == {}
assert ext_pkg.version == ""
assert ext_pkg.metadata == []


def test_extension_point_tools_default_schema():
ep = ExtensionPoint(metadata={"module": "tests.extension.mockextensions.mockext_tool"})
assert "mock_tool" in ep.tools


def test_extension_point_tools_custom_schema():
ep = ExtensionPoint(metadata={"module": "tests.extension.mockextensions.mockext_customschema"})
assert "openai_style_tool" in ep.tools
metadata = ep.tools["openai_style_tool"]["metadata"]
assert "parameters" in metadata


def test_extension_manager_duplicate_tool_name_raises(jp_serverapp):
from jupyter_server.extension.manager import ExtensionManager

manager = ExtensionManager(serverapp=jp_serverapp)
manager.add_extension("tests.extension.mockextensions.mockext_tool", enabled=True)
manager.add_extension("tests.extension.mockextensions.mockext_dupes", enabled=True)
manager.link_all_extensions()

with pytest.raises(ValueError, match="Duplicate tool name detected: 'mock_tool'"):
manager.get_tools()

29 changes: 29 additions & 0 deletions tests/services/tools/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import json
import pytest

@pytest.fixture
def jp_server_config():
return {
"ServerApp": {
"jpserver_extensions": {
"tests.extension.mockextensions.mockext_tool": True,
"tests.extension.mockextensions.mockext_customschema": True,
}
}
}

@pytest.mark.asyncio
async def test_multiple_tools_present(jp_fetch):
response = await jp_fetch("api", "tools", method="GET")
assert response.code == 200

body = json.loads(response.body.decode())
tools = body["discovered_tools"]

# Check default schema tool
assert "mock_tool" in tools
assert "inputSchema" in tools["mock_tool"]

# Check custom schema tool
assert "openai_style_tool" in tools
assert "parameters" in tools["openai_style_tool"]
18 changes: 18 additions & 0 deletions tests/test_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,24 @@ def test_immutable_cache_trait():
assert serverapp.web_app.settings["static_immutable_cache"] == ["/test/immutable"]


# testing get_tools
def test_serverapp_get_tools_empty(jp_serverapp):
# testing the default empty state
tools = jp_serverapp.get_tools()
assert tools == {}

def test_serverapp_get_tools(jp_serverapp):
jp_serverapp.extension_manager.add_extension(
"tests.extension.mockextensions.mockext_tool", enabled=True
)
jp_serverapp.extension_manager.link_all_extensions()

tools = jp_serverapp.get_tools()
assert "mock_tool" in tools
metadata = tools["mock_tool"]["metadata"]
assert metadata["name"] == "mock_tool"


def test():
pass

Expand Down
Loading