[prototype] add example implementation of client id metadata documents (SEP-991) by pcarleton · Pull Request #839 · modelcontextprotocol/typescript-sdk
I have been working on adding CIMD support to my AS and here's what I've come up with. Adding here in case it helps:
import { OAuthClient } from '../server/models/oauth-client'; import axios from 'axios'; import { isSSRFSafeURL } from 'ssrfcheck'; import { Resolver } from 'node:dns/promises'; import { isIP } from 'node:net'; import { getClientMetadata, setClientMetadata } from '@/server/utils/aws'; const resolver = new Resolver(); resolver.setServers(['1.1.1.1']); // Use Cloudflare public DNS export async function getClient(clientId: string, clientSecret?: string) { let url; try { url = new URL(clientId); } catch {} // If the clientId is a valid url, do SSRF check and fetch client metadata if (url) { // # 2. Check the cache const cachedClientMetadata = await getClientMetadata(clientId); if (cachedClientMetadata) { return convertClientMetadataToOAuthClient(cachedClientMetadata); } // # 2. If it's a domain name then resolve the IP to check it directly const originalHostname = url.hostname; if (!isIP(url.hostname)) { // Resolve the IP from DNS const resolvedIps = await resolver.resolve4(url.hostname); let ipAddress = resolvedIps.find(Boolean); if (!ipAddress) { const resolvedIps = await resolver.resolve6(url.hostname); ipAddress = resolvedIps.find(Boolean); if (!ipAddress) { throw new Error('Client URL is not valid'); } } // Replace the hostname with the resolved IP Address url.hostname = ipAddress; } // # 3. Perform SSRF check on the IP const isSafe = isSSRFSafeURL(url.toString(), { allowedProtocols: ['https'], autoPrependProtocol: false, }); if (!isSafe) { throw new Error('Client URL is not allowed'); } // # 4. Send request using resolved IP that was checked let response; try { response = await axios.get(url.toString(), { timeout: 5000, maxContentLength: 5120, maxRedirects: 0, // Avoid SSRF redirect attacks headers: { Accept: 'application/json', // use the original url domain hostname Host: originalHostname, }, }); } catch (e) { throw new Error('Unable to fetch client metadata'); } // # 5. Check the response is a JSON if ( !String(response.headers['content-type'] || '') .split(';') .map(s => s.trim()) .includes('application/json') ) { throw new Error('Client URL must return a JSON response'); } // # 6. Validate the response is a valid client metadata const result = cimdMetadataSchema.safeParse(response.data); if (!result.success) { throw new Error( `Client metadata is invalid: ${result.error.issues.map(issue => issue.message).join(', ')}` ); } // # 7. Validate the client ID matches the requested client ID if (result.data.client_id !== clientId) { throw new Error('Client ID mismatch'); } // # 8. Save the client metadata to the cache and database await setClientMetadata(clientId, result.data); const oauthClient = convertClientMetadataToOAuthClient(result.data); await OAuthClient.upsert(oauthClient); return oauthClient; } const client = await OAuthClient.findOne({ where: { id: clientId, ...(clientSecret && { secret: clientSecret }), }, }); if (!client) throw new Error('Client not found'); return { id: client.id, redirectUris: client.redirectUris, grants: client.grants, accessTokenLifetime: client.accessTokenLifetime, refreshTokenLifetime: client.refreshTokenLifetime, name: client.name, uri: client.uri, scope: client.scope, }; }