[prototype] add example implementation of client id metadata documents (SEP-991) by pcarleton · Pull Request #839 · modelcontextprotocol/typescript-sdk

I have been working on adding CIMD support to my AS and here's what I've come up with. Adding here in case it helps:

import { OAuthClient } from '../server/models/oauth-client';
import axios from 'axios';
import { isSSRFSafeURL } from 'ssrfcheck';
import { Resolver } from 'node:dns/promises';
import { isIP } from 'node:net';
import { getClientMetadata, setClientMetadata } from '@/server/utils/aws';

const resolver = new Resolver();
resolver.setServers(['1.1.1.1']); // Use Cloudflare public DNS

export async function getClient(clientId: string, clientSecret?: string) {
  let url;
  try {
    url = new URL(clientId);
  } catch {}

  // If the clientId is a valid url, do SSRF check and fetch client metadata
  if (url) {
    // # 2. Check the cache
    const cachedClientMetadata = await getClientMetadata(clientId);
    if (cachedClientMetadata) {
      return convertClientMetadataToOAuthClient(cachedClientMetadata);
    }

    // # 2. If it's a domain name then resolve the IP to check it directly
    const originalHostname = url.hostname;
    if (!isIP(url.hostname)) {
      // Resolve the IP from DNS
      const resolvedIps = await resolver.resolve4(url.hostname);
      let ipAddress = resolvedIps.find(Boolean);
      if (!ipAddress) {
        const resolvedIps = await resolver.resolve6(url.hostname);
        ipAddress = resolvedIps.find(Boolean);
        if (!ipAddress) {
          throw new Error('Client URL is not valid');
        }
      }
      // Replace the hostname with the resolved IP Address
      url.hostname = ipAddress;
    }

    // # 3. Perform SSRF check on the IP
    const isSafe = isSSRFSafeURL(url.toString(), {
      allowedProtocols: ['https'],
      autoPrependProtocol: false,
    });
    if (!isSafe) {
      throw new Error('Client URL is not allowed');
    }

    // # 4. Send request using resolved IP that was checked
    let response;
    try {
      response = await axios.get(url.toString(), {
        timeout: 5000,
        maxContentLength: 5120,
        maxRedirects: 0, // Avoid SSRF redirect attacks
        headers: {
          Accept: 'application/json',
          // use the original url domain hostname
          Host: originalHostname,
        },
      });
    } catch (e) {
      throw new Error('Unable to fetch client metadata');
    }

    // # 5. Check the response is a JSON
    if (
      !String(response.headers['content-type'] || '')
        .split(';')
        .map(s => s.trim())
        .includes('application/json')
    ) {
      throw new Error('Client URL must return a JSON response');
    }

    // # 6. Validate the response is a valid client metadata
    const result = cimdMetadataSchema.safeParse(response.data);
    if (!result.success) {
      throw new Error(
        `Client metadata is invalid: ${result.error.issues.map(issue => issue.message).join(', ')}`
      );
    }

    // # 7. Validate the client ID matches the requested client ID
    if (result.data.client_id !== clientId) {
      throw new Error('Client ID mismatch');
    }

    // # 8. Save the client metadata to the cache and database
    await setClientMetadata(clientId, result.data);
    const oauthClient = convertClientMetadataToOAuthClient(result.data);
    await OAuthClient.upsert(oauthClient);

    return oauthClient;
  }

  const client = await OAuthClient.findOne({
    where: {
      id: clientId,
      ...(clientSecret && { secret: clientSecret }),
    },
  });
  if (!client) throw new Error('Client not found');

  return {
    id: client.id,
    redirectUris: client.redirectUris,
    grants: client.grants,
    accessTokenLifetime: client.accessTokenLifetime,
    refreshTokenLifetime: client.refreshTokenLifetime,
    name: client.name,
    uri: client.uri,
    scope: client.scope,
  };
}