Add schema validation to lowlevel server (#1005) · modelcontextprotocol/python-sdk@c8bbfc0
@@ -68,13 +68,15 @@ async def main():
6868from __future__ import annotations as _annotations
69697070import contextvars
71+import json
7172import logging
7273import warnings
7374from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
7475from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
75-from typing import Any, Generic
76+from typing import Any, Generic, TypeAlias, cast
76777778import anyio
79+import jsonschema
7880from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7981from pydantic import AnyUrl
8082from typing_extensions import TypeVar
@@ -94,6 +96,11 @@ async def main():
9496LifespanResultT = TypeVar("LifespanResultT")
9597RequestT = TypeVar("RequestT", default=Any)
969899+# type aliases for tool call results
100+StructuredContent: TypeAlias = dict[str, Any]
101+UnstructuredContent: TypeAlias = Iterable[types.ContentBlock]
102+CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent]
103+97104# This will be properly typed in each Server instance's context
98105request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
99106@@ -143,6 +150,7 @@ def __init__(
143150 }
144151self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
145152self.notification_options = NotificationOptions()
153+self._tool_cache: dict[str, types.Tool] = {}
146154logger.debug("Initializing server %r", name)
147155148156def create_initialization_options(
@@ -373,33 +381,120 @@ def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
373381374382async def handler(_: Any):
375383tools = await func()
384+# Refresh the tool cache
385+self._tool_cache.clear()
386+for tool in tools:
387+self._tool_cache[tool.name] = tool
376388return types.ServerResult(types.ListToolsResult(tools=tools))
377389378390self.request_handlers[types.ListToolsRequest] = handler
379391return func
380392381393return decorator
382394383-def call_tool(self):
395+def _make_error_result(self, error_message: str) -> types.ServerResult:
396+"""Create a ServerResult with an error CallToolResult."""
397+return types.ServerResult(
398+types.CallToolResult(
399+content=[types.TextContent(type="text", text=error_message)],
400+isError=True,
401+ )
402+ )
403+404+async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None:
405+"""Get tool definition from cache, refreshing if necessary.
406+407+ Returns the Tool object if found, None otherwise.
408+ """
409+if tool_name not in self._tool_cache:
410+if types.ListToolsRequest in self.request_handlers:
411+logger.debug("Tool cache miss for %s, refreshing cache", tool_name)
412+await self.request_handlers[types.ListToolsRequest](None)
413+414+tool = self._tool_cache.get(tool_name)
415+if tool is None:
416+logger.warning("Tool '%s' not listed, no validation will be performed", tool_name)
417+418+return tool
419+420+def call_tool(self, *, validate_input: bool = True):
421+"""Register a tool call handler.
422+423+ Args:
424+ validate_input: If True, validates input against inputSchema. Default is True.
425+426+ The handler validates input against inputSchema (if validate_input=True), calls the tool function,
427+ and builds a CallToolResult with the results:
428+ - Unstructured content (iterable of ContentBlock): returned in content
429+ - Structured content (dict): returned in structuredContent, serialized JSON text returned in content
430+ - Both: returned in content and structuredContent
431+432+ If outputSchema is defined, validates structuredContent or errors if missing.
433+ """
434+384435def decorator(
385436func: Callable[
386437 ...,
387-Awaitable[Iterable[types.ContentBlock]],
438+Awaitable[UnstructuredContent | StructuredContent | CombinationContent],
388439 ],
389440 ):
390441logger.debug("Registering handler for CallToolRequest")
391442392443async def handler(req: types.CallToolRequest):
393444try:
394-results = await func(req.params.name, (req.params.arguments or {}))
395-return types.ServerResult(types.CallToolResult(content=list(results), isError=False))
396-except Exception as e:
445+tool_name = req.params.name
446+arguments = req.params.arguments or {}
447+tool = await self._get_cached_tool_definition(tool_name)
448+449+# input validation
450+if validate_input and tool:
451+try:
452+jsonschema.validate(instance=arguments, schema=tool.inputSchema)
453+except jsonschema.ValidationError as e:
454+return self._make_error_result(f"Input validation error: {e.message}")
455+456+# tool call
457+results = await func(tool_name, arguments)
458+459+# output normalization
460+unstructured_content: UnstructuredContent
461+maybe_structured_content: StructuredContent | None
462+if isinstance(results, tuple) and len(results) == 2:
463+# tool returned both structured and unstructured content
464+unstructured_content, maybe_structured_content = cast(CombinationContent, results)
465+elif isinstance(results, dict):
466+# tool returned structured content only
467+maybe_structured_content = cast(StructuredContent, results)
468+unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))]
469+elif hasattr(results, "__iter__"):
470+# tool returned unstructured content only
471+unstructured_content = cast(UnstructuredContent, results)
472+maybe_structured_content = None
473+else:
474+return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}")
475+476+# output validation
477+if tool and tool.outputSchema is not None:
478+if maybe_structured_content is None:
479+return self._make_error_result(
480+"Output validation error: outputSchema defined but no structured output returned"
481+ )
482+else:
483+try:
484+jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema)
485+except jsonschema.ValidationError as e:
486+return self._make_error_result(f"Output validation error: {e.message}")
487+488+# result
397489return types.ServerResult(
398490types.CallToolResult(
399-content=[types.TextContent(type="text", text=str(e))],
400-isError=True,
491+content=list(unstructured_content),
492+structuredContent=maybe_structured_content,
493+isError=False,
401494 )
402495 )
496+except Exception as e:
497+return self._make_error_result(str(e))
403498404499self.request_handlers[types.CallToolRequest] = handler
405500return func