Add schema validation to lowlevel server (#1005) · modelcontextprotocol/python-sdk@c8bbfc0

@@ -68,13 +68,15 @@ async def main():

6868

from __future__ import annotations as _annotations

69697070

import contextvars

71+

import json

7172

import logging

7273

import warnings

7374

from collections.abc import AsyncIterator, Awaitable, Callable, Iterable

7475

from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager

75-

from typing import Any, Generic

76+

from typing import Any, Generic, TypeAlias, cast

76777778

import anyio

79+

import jsonschema

7880

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

7981

from pydantic import AnyUrl

8082

from typing_extensions import TypeVar

@@ -94,6 +96,11 @@ async def main():

9496

LifespanResultT = TypeVar("LifespanResultT")

9597

RequestT = TypeVar("RequestT", default=Any)

969899+

# type aliases for tool call results

100+

StructuredContent: TypeAlias = dict[str, Any]

101+

UnstructuredContent: TypeAlias = Iterable[types.ContentBlock]

102+

CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent]

103+97104

# This will be properly typed in each Server instance's context

98105

request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")

99106

@@ -143,6 +150,7 @@ def __init__(

143150

}

144151

self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}

145152

self.notification_options = NotificationOptions()

153+

self._tool_cache: dict[str, types.Tool] = {}

146154

logger.debug("Initializing server %r", name)

147155148156

def create_initialization_options(

@@ -373,33 +381,120 @@ def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):

373381374382

async def handler(_: Any):

375383

tools = await func()

384+

# Refresh the tool cache

385+

self._tool_cache.clear()

386+

for tool in tools:

387+

self._tool_cache[tool.name] = tool

376388

return types.ServerResult(types.ListToolsResult(tools=tools))

377389378390

self.request_handlers[types.ListToolsRequest] = handler

379391

return func

380392381393

return decorator

382394383-

def call_tool(self):

395+

def _make_error_result(self, error_message: str) -> types.ServerResult:

396+

"""Create a ServerResult with an error CallToolResult."""

397+

return types.ServerResult(

398+

types.CallToolResult(

399+

content=[types.TextContent(type="text", text=error_message)],

400+

isError=True,

401+

)

402+

)

403+404+

async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None:

405+

"""Get tool definition from cache, refreshing if necessary.

406+407+

Returns the Tool object if found, None otherwise.

408+

"""

409+

if tool_name not in self._tool_cache:

410+

if types.ListToolsRequest in self.request_handlers:

411+

logger.debug("Tool cache miss for %s, refreshing cache", tool_name)

412+

await self.request_handlers[types.ListToolsRequest](None)

413+414+

tool = self._tool_cache.get(tool_name)

415+

if tool is None:

416+

logger.warning("Tool '%s' not listed, no validation will be performed", tool_name)

417+418+

return tool

419+420+

def call_tool(self, *, validate_input: bool = True):

421+

"""Register a tool call handler.

422+423+

Args:

424+

validate_input: If True, validates input against inputSchema. Default is True.

425+426+

The handler validates input against inputSchema (if validate_input=True), calls the tool function,

427+

and builds a CallToolResult with the results:

428+

- Unstructured content (iterable of ContentBlock): returned in content

429+

- Structured content (dict): returned in structuredContent, serialized JSON text returned in content

430+

- Both: returned in content and structuredContent

431+432+

If outputSchema is defined, validates structuredContent or errors if missing.

433+

"""

434+384435

def decorator(

385436

func: Callable[

386437

...,

387-

Awaitable[Iterable[types.ContentBlock]],

438+

Awaitable[UnstructuredContent | StructuredContent | CombinationContent],

388439

],

389440

):

390441

logger.debug("Registering handler for CallToolRequest")

391442392443

async def handler(req: types.CallToolRequest):

393444

try:

394-

results = await func(req.params.name, (req.params.arguments or {}))

395-

return types.ServerResult(types.CallToolResult(content=list(results), isError=False))

396-

except Exception as e:

445+

tool_name = req.params.name

446+

arguments = req.params.arguments or {}

447+

tool = await self._get_cached_tool_definition(tool_name)

448+449+

# input validation

450+

if validate_input and tool:

451+

try:

452+

jsonschema.validate(instance=arguments, schema=tool.inputSchema)

453+

except jsonschema.ValidationError as e:

454+

return self._make_error_result(f"Input validation error: {e.message}")

455+456+

# tool call

457+

results = await func(tool_name, arguments)

458+459+

# output normalization

460+

unstructured_content: UnstructuredContent

461+

maybe_structured_content: StructuredContent | None

462+

if isinstance(results, tuple) and len(results) == 2:

463+

# tool returned both structured and unstructured content

464+

unstructured_content, maybe_structured_content = cast(CombinationContent, results)

465+

elif isinstance(results, dict):

466+

# tool returned structured content only

467+

maybe_structured_content = cast(StructuredContent, results)

468+

unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))]

469+

elif hasattr(results, "__iter__"):

470+

# tool returned unstructured content only

471+

unstructured_content = cast(UnstructuredContent, results)

472+

maybe_structured_content = None

473+

else:

474+

return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}")

475+476+

# output validation

477+

if tool and tool.outputSchema is not None:

478+

if maybe_structured_content is None:

479+

return self._make_error_result(

480+

"Output validation error: outputSchema defined but no structured output returned"

481+

)

482+

else:

483+

try:

484+

jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema)

485+

except jsonschema.ValidationError as e:

486+

return self._make_error_result(f"Output validation error: {e.message}")

487+488+

# result

397489

return types.ServerResult(

398490

types.CallToolResult(

399-

content=[types.TextContent(type="text", text=str(e))],

400-

isError=True,

491+

content=list(unstructured_content),

492+

structuredContent=maybe_structured_content,

493+

isError=False,

401494

)

402495

)

496+

except Exception as e:

497+

return self._make_error_result(str(e))

403498404499

self.request_handlers[types.CallToolRequest] = handler

405500

return func