From b90de755f9b9aa334077f771c67ad7c454896925 Mon Sep 17 00:00:00 2001 From: Kit Langton Date: Wed, 25 Mar 2026 20:15:05 -0400 Subject: [PATCH] fix+refactor(mcp): lifecycle tests, cancelPending fix, Effect migration (#19042) --- packages/opencode/specs/effect-migration.md | 2 +- packages/opencode/src/mcp/auth.ts | 215 ++-- packages/opencode/src/mcp/index.ts | 1147 +++++++++--------- packages/opencode/src/mcp/oauth-callback.ts | 29 +- packages/opencode/test/mcp/lifecycle.test.ts | 660 ++++++++++ 5 files changed, 1373 insertions(+), 680 deletions(-) create mode 100644 packages/opencode/test/mcp/lifecycle.test.ts diff --git a/packages/opencode/specs/effect-migration.md b/packages/opencode/specs/effect-migration.md index 073da7b32b..d98750eac9 100644 --- a/packages/opencode/specs/effect-migration.md +++ b/packages/opencode/specs/effect-migration.md @@ -175,4 +175,4 @@ Still open and likely worth migrating: - [ ] `Provider` - [x] `Project` - [ ] `LSP` -- [ ] `MCP` +- [x] `MCP` diff --git a/packages/opencode/src/mcp/auth.ts b/packages/opencode/src/mcp/auth.ts index 399986376d..3c2b93f337 100644 --- a/packages/opencode/src/mcp/auth.ts +++ b/packages/opencode/src/mcp/auth.ts @@ -1,7 +1,9 @@ import path from "path" import z from "zod" import { Global } from "../global" -import { Filesystem } from "../util/filesystem" +import { Effect, Layer, ServiceMap } from "effect" +import { AppFileSystem } from "@/filesystem" +import { makeRunPromise } from "@/effect/run-service" export namespace McpAuth { export const Tokens = z.object({ @@ -25,106 +27,155 @@ export namespace McpAuth { clientInfo: ClientInfo.optional(), codeVerifier: z.string().optional(), oauthState: z.string().optional(), - serverUrl: z.string().optional(), // Track the URL these credentials are for + serverUrl: z.string().optional(), }) export type Entry = z.infer const filepath = path.join(Global.Path.data, "mcp-auth.json") - export async function get(mcpName: string): Promise { - const data = await all() - return data[mcpName] + export interface Interface { + readonly all: () => Effect.Effect> + readonly get: (mcpName: string) => Effect.Effect + readonly getForUrl: (mcpName: string, serverUrl: string) => Effect.Effect + readonly set: (mcpName: string, entry: Entry, serverUrl?: string) => Effect.Effect + readonly remove: (mcpName: string) => Effect.Effect + readonly updateTokens: (mcpName: string, tokens: Tokens, serverUrl?: string) => Effect.Effect + readonly updateClientInfo: (mcpName: string, clientInfo: ClientInfo, serverUrl?: string) => Effect.Effect + readonly updateCodeVerifier: (mcpName: string, codeVerifier: string) => Effect.Effect + readonly clearCodeVerifier: (mcpName: string) => Effect.Effect + readonly updateOAuthState: (mcpName: string, oauthState: string) => Effect.Effect + readonly getOAuthState: (mcpName: string) => Effect.Effect + readonly clearOAuthState: (mcpName: string) => Effect.Effect + readonly isTokenExpired: (mcpName: string) => Effect.Effect } - /** - * Get auth entry and validate it's for the correct URL. - * Returns undefined if URL has changed (credentials are invalid). - */ - export async function getForUrl(mcpName: string, serverUrl: string): Promise { - const entry = await get(mcpName) - if (!entry) return undefined + export class Service extends ServiceMap.Service()("@opencode/McpAuth") {} - // If no serverUrl is stored, this is from an old version - consider it invalid - if (!entry.serverUrl) return undefined + export const layer = Layer.effect( + Service, + Effect.gen(function* () { + const fs = yield* AppFileSystem.Service - // If URL has changed, credentials are invalid - if (entry.serverUrl !== serverUrl) return undefined + const all = Effect.fn("McpAuth.all")(function* () { + return yield* fs.readJson(filepath).pipe( + Effect.map((data) => data as Record), + Effect.catch(() => Effect.succeed({} as Record)), + ) + }) - return entry - } + const get = Effect.fn("McpAuth.get")(function* (mcpName: string) { + const data = yield* all() + return data[mcpName] + }) - export async function all(): Promise> { - return Filesystem.readJson>(filepath).catch(() => ({})) - } + const getForUrl = Effect.fn("McpAuth.getForUrl")(function* (mcpName: string, serverUrl: string) { + const entry = yield* get(mcpName) + if (!entry) return undefined + if (!entry.serverUrl) return undefined + if (entry.serverUrl !== serverUrl) return undefined + return entry + }) - export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise { - const data = await all() - // Always update serverUrl if provided - if (serverUrl) { - entry.serverUrl = serverUrl - } - await Filesystem.writeJson(filepath, { ...data, [mcpName]: entry }, 0o600) - } + const set = Effect.fn("McpAuth.set")(function* (mcpName: string, entry: Entry, serverUrl?: string) { + const data = yield* all() + if (serverUrl) entry.serverUrl = serverUrl + yield* fs.writeJson(filepath, { ...data, [mcpName]: entry }, 0o600).pipe(Effect.orDie) + }) - export async function remove(mcpName: string): Promise { - const data = await all() - delete data[mcpName] - await Filesystem.writeJson(filepath, data, 0o600) - } + const remove = Effect.fn("McpAuth.remove")(function* (mcpName: string) { + const data = yield* all() + delete data[mcpName] + yield* fs.writeJson(filepath, data, 0o600).pipe(Effect.orDie) + }) - export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise { - const entry = (await get(mcpName)) ?? {} - entry.tokens = tokens - await set(mcpName, entry, serverUrl) - } + const updateField = (field: K, spanName: string) => + Effect.fn(`McpAuth.${spanName}`)(function* (mcpName: string, value: NonNullable, serverUrl?: string) { + const entry = (yield* get(mcpName)) ?? {} + entry[field] = value + yield* set(mcpName, entry, serverUrl) + }) - export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo, serverUrl?: string): Promise { - const entry = (await get(mcpName)) ?? {} - entry.clientInfo = clientInfo - await set(mcpName, entry, serverUrl) - } + const clearField = (field: K, spanName: string) => + Effect.fn(`McpAuth.${spanName}`)(function* (mcpName: string) { + const entry = yield* get(mcpName) + if (entry) { + delete entry[field] + yield* set(mcpName, entry) + } + }) - export async function updateCodeVerifier(mcpName: string, codeVerifier: string): Promise { - const entry = (await get(mcpName)) ?? {} - entry.codeVerifier = codeVerifier - await set(mcpName, entry) - } + const updateTokens = updateField("tokens", "updateTokens") + const updateClientInfo = updateField("clientInfo", "updateClientInfo") + const updateCodeVerifier = updateField("codeVerifier", "updateCodeVerifier") + const updateOAuthState = updateField("oauthState", "updateOAuthState") + const clearCodeVerifier = clearField("codeVerifier", "clearCodeVerifier") + const clearOAuthState = clearField("oauthState", "clearOAuthState") - export async function clearCodeVerifier(mcpName: string): Promise { - const entry = await get(mcpName) - if (entry) { - delete entry.codeVerifier - await set(mcpName, entry) - } - } + const getOAuthState = Effect.fn("McpAuth.getOAuthState")(function* (mcpName: string) { + const entry = yield* get(mcpName) + return entry?.oauthState + }) - export async function updateOAuthState(mcpName: string, oauthState: string): Promise { - const entry = (await get(mcpName)) ?? {} - entry.oauthState = oauthState - await set(mcpName, entry) - } + const isTokenExpired = Effect.fn("McpAuth.isTokenExpired")(function* (mcpName: string) { + const entry = yield* get(mcpName) + if (!entry?.tokens) return null + if (!entry.tokens.expiresAt) return false + return entry.tokens.expiresAt < Date.now() / 1000 + }) - export async function getOAuthState(mcpName: string): Promise { - const entry = await get(mcpName) - return entry?.oauthState - } + return Service.of({ + all, + get, + getForUrl, + set, + remove, + updateTokens, + updateClientInfo, + updateCodeVerifier, + clearCodeVerifier, + updateOAuthState, + getOAuthState, + clearOAuthState, + isTokenExpired, + }) + }), + ) - export async function clearOAuthState(mcpName: string): Promise { - const entry = await get(mcpName) - if (entry) { - delete entry.oauthState - await set(mcpName, entry) - } - } + const defaultLayer = layer.pipe(Layer.provide(AppFileSystem.defaultLayer)) - /** - * Check if stored tokens are expired. - * Returns null if no tokens exist, false if no expiry or not expired, true if expired. - */ - export async function isTokenExpired(mcpName: string): Promise { - const entry = await get(mcpName) - if (!entry?.tokens) return null - if (!entry.tokens.expiresAt) return false - return entry.tokens.expiresAt < Date.now() / 1000 - } + const runPromise = makeRunPromise(Service, defaultLayer) + + // Async facades for backward compat (used by McpOAuthProvider, CLI) + + export const get = async (mcpName: string) => runPromise((svc) => svc.get(mcpName)) + + export const getForUrl = async (mcpName: string, serverUrl: string) => + runPromise((svc) => svc.getForUrl(mcpName, serverUrl)) + + export const all = async () => runPromise((svc) => svc.all()) + + export const set = async (mcpName: string, entry: Entry, serverUrl?: string) => + runPromise((svc) => svc.set(mcpName, entry, serverUrl)) + + export const remove = async (mcpName: string) => runPromise((svc) => svc.remove(mcpName)) + + export const updateTokens = async (mcpName: string, tokens: Tokens, serverUrl?: string) => + runPromise((svc) => svc.updateTokens(mcpName, tokens, serverUrl)) + + export const updateClientInfo = async (mcpName: string, clientInfo: ClientInfo, serverUrl?: string) => + runPromise((svc) => svc.updateClientInfo(mcpName, clientInfo, serverUrl)) + + export const updateCodeVerifier = async (mcpName: string, codeVerifier: string) => + runPromise((svc) => svc.updateCodeVerifier(mcpName, codeVerifier)) + + export const clearCodeVerifier = async (mcpName: string) => runPromise((svc) => svc.clearCodeVerifier(mcpName)) + + export const updateOAuthState = async (mcpName: string, oauthState: string) => + runPromise((svc) => svc.updateOAuthState(mcpName, oauthState)) + + export const getOAuthState = async (mcpName: string) => runPromise((svc) => svc.getOAuthState(mcpName)) + + export const clearOAuthState = async (mcpName: string) => runPromise((svc) => svc.clearOAuthState(mcpName)) + + export const isTokenExpired = async (mcpName: string) => runPromise((svc) => svc.isTokenExpired(mcpName)) } diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index bf5a0d3ce7..748f4abf02 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -11,12 +11,12 @@ import { } from "@modelcontextprotocol/sdk/types.js" import { Config } from "../config/config" import { Log } from "../util/log" -import { Process } from "../util/process" import { NamedError } from "@opencode-ai/util/error" import z from "zod/v4" import { Instance } from "../project/instance" import { Installation } from "../installation" import { withTimeout } from "@/util/timeout" +import { AppFileSystem } from "@/filesystem" import { McpOAuthProvider } from "./oauth-provider" import { McpOAuthCallback } from "./oauth-callback" import { McpAuth } from "./auth" @@ -24,6 +24,13 @@ import { BusEvent } from "../bus/bus-event" import { Bus } from "@/bus" import { TuiEvent } from "@/cli/cmd/tui/event" import open from "open" +import { Effect, Layer, Option, ServiceMap, Stream } from "effect" +import { InstanceState } from "@/effect/instance-state" +import { makeRunPromise } from "@/effect/run-service" +import { ChildProcess, ChildProcessSpawner } from "effect/unstable/process" +import * as CrossSpawnSpawner from "@/effect/cross-spawn-spawner" +import { NodeFileSystem } from "@effect/platform-node" +import * as NodePath from "@effect/platform-node/NodePath" export namespace MCP { const log = Log.create({ service: "mcp" }) @@ -109,16 +116,21 @@ export namespace MCP { }) export type Status = z.infer - // Register notification handlers for MCP client - function registerNotificationHandlers(client: MCPClient, serverName: string) { - client.setNotificationHandler(ToolListChangedNotificationSchema, async () => { - log.info("tools list changed notification received", { server: serverName }) - Bus.publish(ToolsChanged, { server: serverName }) - }) + // Store transports for OAuth servers to allow finishing auth + type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport + const pendingOAuthTransports = new Map() + + // Prompt cache types + type PromptInfo = Awaited>["prompts"][number] + type ResourceInfo = Awaited>["resources"][number] + type McpEntry = NonNullable[string] + + function isMcpConfigured(entry: McpEntry): entry is Config.Mcp { + return typeof entry === "object" && entry !== null && "type" in entry } // Convert MCP tool definition to AI SDK Tool type - async function convertMcpTool(mcpTool: MCPToolDef, client: MCPClient, timeout?: number): Promise { + function convertMcpTool(mcpTool: MCPToolDef, client: MCPClient, timeout?: number): Tool { const inputSchema = mcpTool.inputSchema // Spread first, then override type to ensure it's always "object" @@ -148,178 +160,33 @@ export namespace MCP { }) } - // Store transports for OAuth servers to allow finishing auth - type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport - const pendingOAuthTransports = new Map() - - // Prompt cache types - type PromptInfo = Awaited>["prompts"][number] - - type ResourceInfo = Awaited>["resources"][number] - type McpEntry = NonNullable[string] - function isMcpConfigured(entry: McpEntry): entry is Config.Mcp { - return typeof entry === "object" && entry !== null && "type" in entry - } - - async function descendants(pid: number): Promise { - if (process.platform === "win32") return [] - const pids: number[] = [] - const queue = [pid] - while (queue.length > 0) { - const current = queue.shift()! - const lines = await Process.lines(["pgrep", "-P", String(current)], { nothrow: true }) - for (const tok of lines) { - const cpid = parseInt(tok, 10) - if (!isNaN(cpid) && !pids.includes(cpid)) { - pids.push(cpid) - queue.push(cpid) - } - } - } - return pids - } - - const state = Instance.state( - async () => { - const cfg = await Config.get() - const config = cfg.mcp ?? {} - const clients: Record = {} - const status: Record = {} - - await Promise.all( - Object.entries(config).map(async ([key, mcp]) => { - if (!isMcpConfigured(mcp)) { - log.error("Ignoring MCP config entry without type", { key }) - return - } - - // If disabled by config, mark as disabled without trying to connect - if (mcp.enabled === false) { - status[key] = { status: "disabled" } - return - } - - const result = await create(key, mcp).catch(() => undefined) - if (!result) return - - status[key] = result.status - - if (result.mcpClient) { - clients[key] = result.mcpClient - } - }), - ) - return { - status, - clients, - } - }, - async (state) => { - // The MCP SDK only signals the direct child process on close. - // Servers like chrome-devtools-mcp spawn grandchild processes - // (e.g. Chrome) that the SDK never reaches, leaving them orphaned. - // Kill the full descendant tree first so the server exits promptly - // and no processes are left behind. - for (const client of Object.values(state.clients)) { - const pid = (client.transport as any)?.pid - if (typeof pid !== "number") continue - for (const dpid of await descendants(pid)) { - try { - process.kill(dpid, "SIGTERM") - } catch {} - } - } - - await Promise.all( - Object.values(state.clients).map((client) => - client.close().catch((error) => { - log.error("Failed to close MCP client", { - error, - }) - }), - ), - ) - pendingOAuthTransports.clear() - }, - ) - - // Helper function to fetch prompts for a specific client - async function fetchPromptsForClient(clientName: string, client: Client) { - const prompts = await client.listPrompts().catch((e) => { - log.error("failed to get prompts", { clientName, error: e.message }) + async function defs(key: string, client: MCPClient, timeout?: number) { + const result = await withTimeout(client.listTools(), timeout ?? DEFAULT_TIMEOUT).catch((err) => { + log.error("failed to get tools from client", { key, error: err }) return undefined }) - - if (!prompts) { - return - } - - const commands: Record = {} - - for (const prompt of prompts.prompts) { - const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_") - const sanitizedPromptName = prompt.name.replace(/[^a-zA-Z0-9_-]/g, "_") - const key = sanitizedClientName + ":" + sanitizedPromptName - - commands[key] = { ...prompt, client: clientName } - } - return commands + return result?.tools } - async function fetchResourcesForClient(clientName: string, client: Client) { - const resources = await client.listResources().catch((e) => { - log.error("failed to get prompts", { clientName, error: e.message }) + async function fetchFromClient( + clientName: string, + client: Client, + listFn: (c: Client) => Promise, + label: string, + ): Promise | undefined> { + const items = await listFn(client).catch((e: any) => { + log.error(`failed to get ${label}`, { clientName, error: e.message }) return undefined }) + if (!items) return undefined - if (!resources) { - return - } - - const commands: Record = {} - - for (const resource of resources.resources) { - const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_") - const sanitizedResourceName = resource.name.replace(/[^a-zA-Z0-9_-]/g, "_") - const key = sanitizedClientName + ":" + sanitizedResourceName - - commands[key] = { ...resource, client: clientName } - } - return commands - } - - export async function add(name: string, mcp: Config.Mcp) { - const s = await state() - const result = await create(name, mcp) - if (!result) { - const status = { - status: "failed" as const, - error: "unknown error", - } - s.status[name] = status - return { - status, - } - } - if (!result.mcpClient) { - s.status[name] = result.status - return { - status: s.status, - } - } - // Close existing client if present to prevent memory leaks - const existingClient = s.clients[name] - if (existingClient) { - await existingClient.close().catch((error) => { - log.error("Failed to close existing MCP client", { name, error }) - }) - } - s.clients[name] = result.mcpClient - s.status[name] = result.status - - return { - status: s.status, + const out: Record = {} + const sanitizedClient = clientName.replace(/[^a-zA-Z0-9_-]/g, "_") + for (const item of items) { + const sanitizedName = item.name.replace(/[^a-zA-Z0-9_-]/g, "_") + out[sanitizedClient + ":" + sanitizedName] = { ...item, client: clientName } } + return out } async function create(key: string, mcp: Config.Mcp) { @@ -385,7 +252,6 @@ export namespace MCP { version: Installation.VERSION, }) await withTimeout(client.connect(transport), connectTimeout) - registerNotificationHandlers(client, key) mcpClient = client log.info("connected", { key, transport: name }) status = { status: "connected" } @@ -470,7 +336,6 @@ export namespace MCP { version: Installation.VERSION, }) await withTimeout(client.connect(transport), connectTimeout) - registerNotificationHandlers(client, key) mcpClient = client status = { status: "connected", @@ -503,475 +368,569 @@ export namespace MCP { } } - const result = await withTimeout(mcpClient.listTools(), mcp.timeout ?? DEFAULT_TIMEOUT).catch((err) => { - log.error("failed to get tools from client", { key, error: err }) - return undefined - }) - if (!result) { + const listed = await defs(key, mcpClient, mcp.timeout) + if (!listed) { await mcpClient.close().catch((error) => { log.error("Failed to close MCP client", { error, }) }) - status = { - status: "failed", - error: "Failed to get tools", - } return { mcpClient: undefined, - status: { - status: "failed" as const, - error: "Failed to get tools", - }, + status: { status: "failed" as const, error: "Failed to get tools" }, } } - log.info("create() successfully created client", { key, toolCount: result.tools.length }) + log.info("create() successfully created client", { key, toolCount: listed.length }) return { mcpClient, status, + defs: listed, } } - export async function status() { - const s = await state() - const cfg = await Config.get() - const config = cfg.mcp ?? {} - const result: Record = {} + // --- Effect Service --- - // Include all configured MCPs from config, not just connected ones - for (const [key, mcp] of Object.entries(config)) { - if (!isMcpConfigured(mcp)) continue - result[key] = s.status[key] ?? { status: "disabled" } - } - - return result + interface State { + status: Record + clients: Record + defs: Record } - export async function clients() { - return state().then((state) => state.clients) + export interface Interface { + readonly status: () => Effect.Effect> + readonly clients: () => Effect.Effect> + readonly tools: () => Effect.Effect> + readonly prompts: () => Effect.Effect> + readonly resources: () => Effect.Effect> + readonly add: (name: string, mcp: Config.Mcp) => Effect.Effect<{ status: Record | Status }> + readonly connect: (name: string) => Effect.Effect + readonly disconnect: (name: string) => Effect.Effect + readonly getPrompt: ( + clientName: string, + name: string, + args?: Record, + ) => Effect.Effect> | undefined> + readonly readResource: ( + clientName: string, + resourceUri: string, + ) => Effect.Effect> | undefined> + readonly startAuth: (mcpName: string) => Effect.Effect<{ authorizationUrl: string; oauthState: string }> + readonly authenticate: (mcpName: string) => Effect.Effect + readonly finishAuth: (mcpName: string, authorizationCode: string) => Effect.Effect + readonly removeAuth: (mcpName: string) => Effect.Effect + readonly supportsOAuth: (mcpName: string) => Effect.Effect + readonly hasStoredTokens: (mcpName: string) => Effect.Effect + readonly getAuthStatus: (mcpName: string) => Effect.Effect } - export async function connect(name: string) { - const cfg = await Config.get() - const config = cfg.mcp ?? {} - const mcp = config[name] - if (!mcp) { - log.error("MCP config not found", { name }) - return - } + export class Service extends ServiceMap.Service()("@opencode/MCP") {} - if (!isMcpConfigured(mcp)) { - log.error("Ignoring MCP connect request for config without type", { name }) - return - } + export const layer = Layer.effect( + Service, + Effect.gen(function* () { + const spawner = yield* ChildProcessSpawner.ChildProcessSpawner + const auth = yield* McpAuth.Service - const result = await create(name, { ...mcp, enabled: true }) - - if (!result) { - const s = await state() - s.status[name] = { - status: "failed", - error: "Unknown error during connection", - } - return - } - - const s = await state() - s.status[name] = result.status - if (result.mcpClient) { - // Close existing client if present to prevent memory leaks - const existingClient = s.clients[name] - if (existingClient) { - await existingClient.close().catch((error) => { - log.error("Failed to close existing MCP client", { name, error }) - }) - } - s.clients[name] = result.mcpClient - } - } - - export async function disconnect(name: string) { - const s = await state() - const client = s.clients[name] - if (client) { - await client.close().catch((error) => { - log.error("Failed to close MCP client", { name, error }) - }) - delete s.clients[name] - } - s.status[name] = { status: "disabled" } - } - - export async function tools() { - const result: Record = {} - const s = await state() - const cfg = await Config.get() - const config = cfg.mcp ?? {} - const clientsSnapshot = await clients() - const defaultTimeout = cfg.experimental?.mcp_timeout - - const connectedClients = Object.entries(clientsSnapshot).filter( - ([clientName]) => s.status[clientName]?.status === "connected", - ) - - const toolsResults = await Promise.all( - connectedClients.map(async ([clientName, client]) => { - const toolsResult = await client.listTools().catch((e) => { - log.error("failed to get tools", { clientName, error: e.message }) - const failedStatus = { - status: "failed" as const, - error: e instanceof Error ? e.message : String(e), + const descendants = Effect.fnUntraced( + function* (pid: number) { + if (process.platform === "win32") return [] as number[] + const pids: number[] = [] + const queue = [pid] + while (queue.length > 0) { + const current = queue.shift()! + const handle = yield* spawner.spawn( + ChildProcess.make("pgrep", ["-P", String(current)], { stdin: "ignore" }), + ) + const text = yield* Stream.mkString(Stream.decodeText(handle.stdout)) + yield* handle.exitCode + for (const tok of text.split("\n")) { + const cpid = parseInt(tok, 10) + if (!isNaN(cpid) && !pids.includes(cpid)) { + pids.push(cpid) + queue.push(cpid) + } + } } - s.status[clientName] = failedStatus - delete s.clients[clientName] - return undefined - }) - return { clientName, client, toolsResult } - }), - ) - - for (const { clientName, client, toolsResult } of toolsResults) { - if (!toolsResult) continue - const mcpConfig = config[clientName] - const entry = isMcpConfigured(mcpConfig) ? mcpConfig : undefined - const timeout = entry?.timeout ?? defaultTimeout - for (const mcpTool of toolsResult.tools) { - const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_") - const sanitizedToolName = mcpTool.name.replace(/[^a-zA-Z0-9_-]/g, "_") - result[sanitizedClientName + "_" + sanitizedToolName] = await convertMcpTool(mcpTool, client, timeout) - } - } - return result - } - - export async function prompts() { - const s = await state() - const clientsSnapshot = await clients() - - const prompts = Object.fromEntries( - ( - await Promise.all( - Object.entries(clientsSnapshot).map(async ([clientName, client]) => { - if (s.status[clientName]?.status !== "connected") { - return [] - } - - return Object.entries((await fetchPromptsForClient(clientName, client)) ?? {}) - }), - ) - ).flat(), - ) - - return prompts - } - - export async function resources() { - const s = await state() - const clientsSnapshot = await clients() - - const result = Object.fromEntries( - ( - await Promise.all( - Object.entries(clientsSnapshot).map(async ([clientName, client]) => { - if (s.status[clientName]?.status !== "connected") { - return [] - } - - return Object.entries((await fetchResourcesForClient(clientName, client)) ?? {}) - }), - ) - ).flat(), - ) - - return result - } - - export async function getPrompt(clientName: string, name: string, args?: Record) { - const clientsSnapshot = await clients() - const client = clientsSnapshot[clientName] - - if (!client) { - log.warn("client not found for prompt", { - clientName, - }) - return undefined - } - - const result = await client - .getPrompt({ - name: name, - arguments: args, - }) - .catch((e) => { - log.error("failed to get prompt from MCP server", { - clientName, - promptName: name, - error: e.message, - }) - return undefined - }) - - return result - } - - export async function readResource(clientName: string, resourceUri: string) { - const clientsSnapshot = await clients() - const client = clientsSnapshot[clientName] - - if (!client) { - log.warn("client not found for prompt", { - clientName: clientName, - }) - return undefined - } - - const result = await client - .readResource({ - uri: resourceUri, - }) - .catch((e) => { - log.error("failed to get prompt from MCP server", { - clientName: clientName, - resourceUri: resourceUri, - error: e.message, - }) - return undefined - }) - - return result - } - - /** - * Start OAuth authentication flow for an MCP server. - * Returns the authorization URL that should be opened in a browser. - */ - export async function startAuth(mcpName: string): Promise<{ authorizationUrl: string }> { - const cfg = await Config.get() - const mcpConfig = cfg.mcp?.[mcpName] - - if (!mcpConfig) { - throw new Error(`MCP server not found: ${mcpName}`) - } - - if (!isMcpConfigured(mcpConfig)) { - throw new Error(`MCP server ${mcpName} is disabled or missing configuration`) - } - - if (mcpConfig.type !== "remote") { - throw new Error(`MCP server ${mcpName} is not a remote server`) - } - - if (mcpConfig.oauth === false) { - throw new Error(`MCP server ${mcpName} has OAuth explicitly disabled`) - } - - // Start the callback server - await McpOAuthCallback.ensureRunning() - - // Generate and store a cryptographically secure state parameter BEFORE creating the provider - // The SDK will call provider.state() to read this value - const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32))) - .map((b) => b.toString(16).padStart(2, "0")) - .join("") - await McpAuth.updateOAuthState(mcpName, oauthState) - - // Create a new auth provider for this flow - // OAuth config is optional - if not provided, we'll use auto-discovery - const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined - let capturedUrl: URL | undefined - const authProvider = new McpOAuthProvider( - mcpName, - mcpConfig.url, - { - clientId: oauthConfig?.clientId, - clientSecret: oauthConfig?.clientSecret, - scope: oauthConfig?.scope, - }, - { - onRedirect: async (url) => { - capturedUrl = url + return pids }, - }, - ) + Effect.scoped, + Effect.catch(() => Effect.succeed([] as number[])), + ) - // Create transport with auth provider - const transport = new StreamableHTTPClientTransport(new URL(mcpConfig.url), { - authProvider, - }) + function watch(s: State, name: string, client: MCPClient, timeout?: number) { + client.setNotificationHandler(ToolListChangedNotificationSchema, async () => { + log.info("tools list changed notification received", { server: name }) + if (s.clients[name] !== client || s.status[name]?.status !== "connected") return - // Try to connect - this will trigger the OAuth flow - try { - const client = new Client({ - name: "opencode", - version: Installation.VERSION, - }) - await client.connect(transport) - // If we get here, we're already authenticated - return { authorizationUrl: "" } - } catch (error) { - if (error instanceof UnauthorizedError && capturedUrl) { - // Store transport for finishAuth - pendingOAuthTransports.set(mcpName, transport) - return { authorizationUrl: capturedUrl.toString() } - } - throw error - } - } + const listed = await defs(name, client, timeout) + if (!listed) return + if (s.clients[name] !== client || s.status[name]?.status !== "connected") return - /** - * Complete OAuth authentication after user authorizes in browser. - * Opens the browser and waits for callback. - */ - export async function authenticate(mcpName: string): Promise { - const { authorizationUrl } = await startAuth(mcpName) - - if (!authorizationUrl) { - // Already authenticated - const s = await state() - return s.status[mcpName] ?? { status: "connected" } - } - - // Get the state that was already generated and stored in startAuth() - const oauthState = await McpAuth.getOAuthState(mcpName) - if (!oauthState) { - throw new Error("OAuth state not found - this should not happen") - } - - // The SDK has already added the state parameter to the authorization URL - // We just need to open the browser - log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState }) - - // Register the callback BEFORE opening the browser to avoid race condition - // when the IdP has an active SSO session and redirects immediately - const callbackPromise = McpOAuthCallback.waitForCallback(oauthState) - - try { - const subprocess = await open(authorizationUrl) - // The open package spawns a detached process and returns immediately. - // We need to listen for errors which fire asynchronously: - // - "error" event: command not found (ENOENT) - // - "exit" with non-zero code: command exists but failed (e.g., no display) - await new Promise((resolve, reject) => { - // Give the process a moment to fail if it's going to - const timeout = setTimeout(() => resolve(), 500) - subprocess.on("error", (error) => { - clearTimeout(timeout) - reject(error) + s.defs[name] = listed + await Bus.publish(ToolsChanged, { server: name }).catch((error) => + log.warn("failed to publish tools changed", { server: name, error }), + ) }) - subprocess.on("exit", (code) => { - if (code !== null && code !== 0) { - clearTimeout(timeout) - reject(new Error(`Browser open failed with exit code ${code}`)) + } + + const cache = yield* InstanceState.make( + Effect.fn("MCP.state")(function* () { + const cfg = yield* Effect.promise(() => Config.get()) + const config = cfg.mcp ?? {} + const s: State = { + status: {}, + clients: {}, + defs: {}, + } + + yield* Effect.forEach( + Object.entries(config), + ([key, mcp]) => + Effect.gen(function* () { + if (!isMcpConfigured(mcp)) { + log.error("Ignoring MCP config entry without type", { key }) + return + } + + if (mcp.enabled === false) { + s.status[key] = { status: "disabled" } + return + } + + const result = yield* Effect.promise(() => create(key, mcp).catch(() => undefined)) + if (!result) return + + s.status[key] = result.status + if (result.mcpClient) { + s.clients[key] = result.mcpClient + s.defs[key] = result.defs + watch(s, key, result.mcpClient, mcp.timeout) + } + }), + { concurrency: "unbounded" }, + ) + + yield* Effect.addFinalizer(() => + Effect.gen(function* () { + yield* Effect.forEach( + Object.values(s.clients), + (client) => + Effect.gen(function* () { + const pid = (client.transport as any)?.pid + if (typeof pid === "number") { + const pids = yield* descendants(pid) + for (const dpid of pids) { + try { + process.kill(dpid, "SIGTERM") + } catch {} + } + } + yield* Effect.tryPromise(() => client.close()).pipe(Effect.ignore) + }), + { concurrency: "unbounded" }, + ) + pendingOAuthTransports.clear() + }), + ) + + return s + }), + ) + + function closeClient(s: State, name: string) { + const client = s.clients[name] + delete s.defs[name] + if (!client) return Effect.void + return Effect.promise(() => + client.close().catch((error: any) => log.error("failed to close MCP client", { name, error })), + ) + } + + const status = Effect.fn("MCP.status")(function* () { + const s = yield* InstanceState.get(cache) + const cfg = yield* Effect.promise(() => Config.get()) + const config = cfg.mcp ?? {} + const result: Record = {} + + for (const [key, mcp] of Object.entries(config)) { + if (!isMcpConfigured(mcp)) continue + result[key] = s.status[key] ?? { status: "disabled" } + } + + return result + }) + + const clients = Effect.fn("MCP.clients")(function* () { + const s = yield* InstanceState.get(cache) + return s.clients + }) + + const createAndStore = Effect.fn("MCP.createAndStore")(function* (name: string, mcp: Config.Mcp) { + const s = yield* InstanceState.get(cache) + const result = yield* Effect.promise(() => create(name, mcp)) + + if (!result) { + yield* closeClient(s, name) + delete s.clients[name] + s.status[name] = { status: "failed" as const, error: "unknown error" } + return s.status[name] + } + + s.status[name] = result.status + if (!result.mcpClient) { + yield* closeClient(s, name) + delete s.clients[name] + return result.status + } + + yield* closeClient(s, name) + s.clients[name] = result.mcpClient + s.defs[name] = result.defs + watch(s, name, result.mcpClient, mcp.timeout) + return result.status + }) + + const add = Effect.fn("MCP.add")(function* (name: string, mcp: Config.Mcp) { + yield* createAndStore(name, mcp) + const s = yield* InstanceState.get(cache) + return { status: s.status } + }) + + const connect = Effect.fn("MCP.connect")(function* (name: string) { + const mcp = yield* getMcpConfig(name) + if (!mcp) { + log.error("MCP config not found or invalid", { name }) + return + } + yield* createAndStore(name, { ...mcp, enabled: true }) + }) + + const disconnect = Effect.fn("MCP.disconnect")(function* (name: string) { + const s = yield* InstanceState.get(cache) + yield* closeClient(s, name) + delete s.clients[name] + s.status[name] = { status: "disabled" } + }) + + const tools = Effect.fn("MCP.tools")(function* () { + const result: Record = {} + const s = yield* InstanceState.get(cache) + const cfg = yield* Effect.promise(() => Config.get()) + const config = cfg.mcp ?? {} + const defaultTimeout = cfg.experimental?.mcp_timeout + + const connectedClients = Object.entries(s.clients).filter( + ([clientName]) => s.status[clientName]?.status === "connected", + ) + + yield* Effect.forEach( + connectedClients, + ([clientName, client]) => + Effect.gen(function* () { + const mcpConfig = config[clientName] + const entry = mcpConfig && isMcpConfigured(mcpConfig) ? mcpConfig : undefined + + const listed = s.defs[clientName] + if (!listed) { + log.warn("missing cached tools for connected server", { clientName }) + return + } + + const timeout = entry?.timeout ?? defaultTimeout + for (const mcpTool of listed) { + const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_") + const sanitizedToolName = mcpTool.name.replace(/[^a-zA-Z0-9_-]/g, "_") + result[sanitizedClientName + "_" + sanitizedToolName] = convertMcpTool(mcpTool, client, timeout) + } + }), + { concurrency: "unbounded" }, + ) + return result + }) + + function collectFromConnected( + s: State, + fetchFn: (clientName: string, client: Client) => Promise | undefined>, + ) { + return Effect.forEach( + Object.entries(s.clients).filter(([name]) => s.status[name]?.status === "connected"), + ([clientName, client]) => + Effect.promise(async () => Object.entries((await fetchFn(clientName, client)) ?? {})), + { concurrency: "unbounded" }, + ).pipe(Effect.map((results) => Object.fromEntries(results.flat()))) + } + + const prompts = Effect.fn("MCP.prompts")(function* () { + const s = yield* InstanceState.get(cache) + return yield* collectFromConnected(s, (name, client) => + fetchFromClient(name, client, (c) => c.listPrompts().then((r) => r.prompts), "prompts"), + ) + }) + + const resources = Effect.fn("MCP.resources")(function* () { + const s = yield* InstanceState.get(cache) + return yield* collectFromConnected(s, (name, client) => + fetchFromClient(name, client, (c) => c.listResources().then((r) => r.resources), "resources"), + ) + }) + + const withClient = Effect.fnUntraced(function* ( + clientName: string, + fn: (client: MCPClient) => Promise, + label: string, + meta?: Record, + ) { + const s = yield* InstanceState.get(cache) + const client = s.clients[clientName] + if (!client) { + log.warn(`client not found for ${label}`, { clientName }) + return undefined + } + return yield* Effect.tryPromise({ + try: () => fn(client), + catch: (e: any) => { + log.error(`failed to ${label}`, { clientName, ...meta, error: e?.message }) + return e + }, + }).pipe(Effect.orElseSucceed(() => undefined)) + }) + + const getPrompt = Effect.fn("MCP.getPrompt")(function* ( + clientName: string, + name: string, + args?: Record, + ) { + return yield* withClient(clientName, (client) => client.getPrompt({ name, arguments: args }), "getPrompt", { + promptName: name, + }) + }) + + const readResource = Effect.fn("MCP.readResource")(function* (clientName: string, resourceUri: string) { + return yield* withClient(clientName, (client) => client.readResource({ uri: resourceUri }), "readResource", { + resourceUri, + }) + }) + + const getMcpConfig = Effect.fnUntraced(function* (mcpName: string) { + const cfg = yield* Effect.promise(() => Config.get()) + const mcpConfig = cfg.mcp?.[mcpName] + if (!mcpConfig || !isMcpConfigured(mcpConfig)) return undefined + return mcpConfig + }) + + const startAuth = Effect.fn("MCP.startAuth")(function* (mcpName: string) { + const mcpConfig = yield* getMcpConfig(mcpName) + if (!mcpConfig) throw new Error(`MCP server ${mcpName} not found or disabled`) + if (mcpConfig.type !== "remote") throw new Error(`MCP server ${mcpName} is not a remote server`) + if (mcpConfig.oauth === false) throw new Error(`MCP server ${mcpName} has OAuth explicitly disabled`) + + yield* Effect.promise(() => McpOAuthCallback.ensureRunning()) + + const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32))) + .map((b) => b.toString(16).padStart(2, "0")) + .join("") + yield* auth.updateOAuthState(mcpName, oauthState) + const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined + let capturedUrl: URL | undefined + const authProvider = new McpOAuthProvider( + mcpName, + mcpConfig.url, + { + clientId: oauthConfig?.clientId, + clientSecret: oauthConfig?.clientSecret, + scope: oauthConfig?.scope, + }, + { + onRedirect: async (url) => { + capturedUrl = url + }, + }, + ) + + const transport = new StreamableHTTPClientTransport(new URL(mcpConfig.url), { authProvider }) + + return yield* Effect.promise(async () => { + try { + const client = new Client({ name: "opencode", version: Installation.VERSION }) + await client.connect(transport) + return { authorizationUrl: "", oauthState } + } catch (error) { + if (error instanceof UnauthorizedError && capturedUrl) { + pendingOAuthTransports.set(mcpName, transport) + return { authorizationUrl: capturedUrl.toString(), oauthState } + } + throw error } }) }) - } catch (error) { - // Browser opening failed (e.g., in remote/headless sessions like SSH, devcontainers) - // Emit event so CLI can display the URL for manual opening - log.warn("failed to open browser, user must open URL manually", { mcpName, error }) - Bus.publish(BrowserOpenFailed, { mcpName, url: authorizationUrl }) - } - // Wait for callback using the already-registered promise - const code = await callbackPromise + const authenticate = Effect.fn("MCP.authenticate")(function* (mcpName: string) { + const { authorizationUrl, oauthState } = yield* startAuth(mcpName) + if (!authorizationUrl) return { status: "connected" } as Status - // Validate and clear the state - const storedState = await McpAuth.getOAuthState(mcpName) - if (storedState !== oauthState) { - await McpAuth.clearOAuthState(mcpName) - throw new Error("OAuth state mismatch - potential CSRF attack") - } + log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState }) - await McpAuth.clearOAuthState(mcpName) + const callbackPromise = McpOAuthCallback.waitForCallback(oauthState, mcpName) - // Finish auth - return finishAuth(mcpName, code) - } + yield* Effect.tryPromise(() => open(authorizationUrl)).pipe( + Effect.flatMap((subprocess) => + Effect.callback((resume) => { + const timer = setTimeout(() => resume(Effect.void), 500) + subprocess.on("error", (err) => { + clearTimeout(timer) + resume(Effect.fail(err)) + }) + subprocess.on("exit", (code) => { + if (code !== null && code !== 0) { + clearTimeout(timer) + resume(Effect.fail(new Error(`Browser open failed with exit code ${code}`))) + } + }) + }), + ), + Effect.catch(() => { + log.warn("failed to open browser, user must open URL manually", { mcpName }) + return Effect.promise(() => Bus.publish(BrowserOpenFailed, { mcpName, url: authorizationUrl })) + }), + ) - /** - * Complete OAuth authentication with the authorization code. - */ - export async function finishAuth(mcpName: string, authorizationCode: string): Promise { - const transport = pendingOAuthTransports.get(mcpName) + const code = yield* Effect.promise(() => callbackPromise) - if (!transport) { - throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`) - } + const storedState = yield* auth.getOAuthState(mcpName) + if (storedState !== oauthState) { + yield* auth.clearOAuthState(mcpName) + throw new Error("OAuth state mismatch - potential CSRF attack") + } + yield* auth.clearOAuthState(mcpName) + return yield* finishAuth(mcpName, code) + }) - try { - // Call finishAuth on the transport - await transport.finishAuth(authorizationCode) + const finishAuth = Effect.fn("MCP.finishAuth")(function* (mcpName: string, authorizationCode: string) { + const transport = pendingOAuthTransports.get(mcpName) + if (!transport) throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`) - // Clear the code verifier after successful auth - await McpAuth.clearCodeVerifier(mcpName) + const result = yield* Effect.tryPromise({ + try: async () => { + await transport.finishAuth(authorizationCode) + return true + }, + catch: (error) => { + log.error("failed to finish oauth", { mcpName, error }) + return error + }, + }).pipe(Effect.option) - // Now try to reconnect - const cfg = await Config.get() - const mcpConfig = cfg.mcp?.[mcpName] + if (Option.isNone(result)) { + return { status: "failed", error: "OAuth completion failed" } as Status + } - if (!mcpConfig) { - throw new Error(`MCP server not found: ${mcpName}`) - } + yield* auth.clearCodeVerifier(mcpName) + pendingOAuthTransports.delete(mcpName) - if (!isMcpConfigured(mcpConfig)) { - throw new Error(`MCP server ${mcpName} is disabled or missing configuration`) - } + const mcpConfig = yield* getMcpConfig(mcpName) + if (!mcpConfig) return { status: "failed", error: "MCP config not found after auth" } as Status - // Re-add the MCP server to establish connection - pendingOAuthTransports.delete(mcpName) - const result = await add(mcpName, mcpConfig) + return yield* createAndStore(mcpName, mcpConfig) + }) - const statusRecord = result.status as Record - return statusRecord[mcpName] ?? { status: "failed", error: "Unknown error after auth" } - } catch (error) { - log.error("failed to finish oauth", { mcpName, error }) - return { - status: "failed", - error: error instanceof Error ? error.message : String(error), - } - } - } + const removeAuth = Effect.fn("MCP.removeAuth")(function* (mcpName: string) { + yield* auth.remove(mcpName) + McpOAuthCallback.cancelPending(mcpName) + pendingOAuthTransports.delete(mcpName) + log.info("removed oauth credentials", { mcpName }) + }) - /** - * Remove OAuth credentials for an MCP server. - */ - export async function removeAuth(mcpName: string): Promise { - await McpAuth.remove(mcpName) - McpOAuthCallback.cancelPending(mcpName) - pendingOAuthTransports.delete(mcpName) - await McpAuth.clearOAuthState(mcpName) - log.info("removed oauth credentials", { mcpName }) - } + const supportsOAuth = Effect.fn("MCP.supportsOAuth")(function* (mcpName: string) { + const mcpConfig = yield* getMcpConfig(mcpName) + if (!mcpConfig) return false + return mcpConfig.type === "remote" && mcpConfig.oauth !== false + }) - /** - * Check if an MCP server supports OAuth (remote servers support OAuth by default unless explicitly disabled). - */ - export async function supportsOAuth(mcpName: string): Promise { - const cfg = await Config.get() - const mcpConfig = cfg.mcp?.[mcpName] - if (!mcpConfig) return false - if (!isMcpConfigured(mcpConfig)) return false - return mcpConfig.type === "remote" && mcpConfig.oauth !== false - } + const hasStoredTokens = Effect.fn("MCP.hasStoredTokens")(function* (mcpName: string) { + const entry = yield* auth.get(mcpName) + return !!entry?.tokens + }) - /** - * Check if an MCP server has stored OAuth tokens. - */ - export async function hasStoredTokens(mcpName: string): Promise { - const entry = await McpAuth.get(mcpName) - return !!entry?.tokens - } + const getAuthStatus = Effect.fn("MCP.getAuthStatus")(function* (mcpName: string) { + const entry = yield* auth.get(mcpName) + if (!entry?.tokens) return "not_authenticated" as AuthStatus + const expired = yield* auth.isTokenExpired(mcpName) + return (expired ? "expired" : "authenticated") as AuthStatus + }) + + return Service.of({ + status, + clients, + tools, + prompts, + resources, + add, + connect, + disconnect, + getPrompt, + readResource, + startAuth, + authenticate, + finishAuth, + removeAuth, + supportsOAuth, + hasStoredTokens, + getAuthStatus, + }) + }), + ) export type AuthStatus = "authenticated" | "expired" | "not_authenticated" - /** - * Get the authentication status for an MCP server. - */ - export async function getAuthStatus(mcpName: string): Promise { - const hasTokens = await hasStoredTokens(mcpName) - if (!hasTokens) return "not_authenticated" - const expired = await McpAuth.isTokenExpired(mcpName) - return expired ? "expired" : "authenticated" - } + // --- Per-service runtime --- + + const defaultLayer = layer.pipe( + Layer.provide(McpAuth.layer), + Layer.provide(CrossSpawnSpawner.layer), + Layer.provide(AppFileSystem.defaultLayer), + Layer.provide(NodeFileSystem.layer), + Layer.provide(NodePath.layer), + ) + + const runPromise = makeRunPromise(Service, defaultLayer) + + // --- Async facade functions --- + + export const status = async () => runPromise((svc) => svc.status()) + + export const clients = async () => runPromise((svc) => svc.clients()) + + export const tools = async () => runPromise((svc) => svc.tools()) + + export const prompts = async () => runPromise((svc) => svc.prompts()) + + export const resources = async () => runPromise((svc) => svc.resources()) + + export const add = async (name: string, mcp: Config.Mcp) => runPromise((svc) => svc.add(name, mcp)) + + export const connect = async (name: string) => runPromise((svc) => svc.connect(name)) + + export const disconnect = async (name: string) => runPromise((svc) => svc.disconnect(name)) + + export const getPrompt = async (clientName: string, name: string, args?: Record) => + runPromise((svc) => svc.getPrompt(clientName, name, args)) + + export const readResource = async (clientName: string, resourceUri: string) => + runPromise((svc) => svc.readResource(clientName, resourceUri)) + + export const startAuth = async (mcpName: string) => runPromise((svc) => svc.startAuth(mcpName)) + + export const authenticate = async (mcpName: string) => runPromise((svc) => svc.authenticate(mcpName)) + + export const finishAuth = async (mcpName: string, authorizationCode: string) => + runPromise((svc) => svc.finishAuth(mcpName, authorizationCode)) + + export const removeAuth = async (mcpName: string) => runPromise((svc) => svc.removeAuth(mcpName)) + + export const supportsOAuth = async (mcpName: string) => runPromise((svc) => svc.supportsOAuth(mcpName)) + + export const hasStoredTokens = async (mcpName: string) => runPromise((svc) => svc.hasStoredTokens(mcpName)) + + export const getAuthStatus = async (mcpName: string) => runPromise((svc) => svc.getAuthStatus(mcpName)) } diff --git a/packages/opencode/src/mcp/oauth-callback.ts b/packages/opencode/src/mcp/oauth-callback.ts index db8e621d6c..3a1ca54044 100644 --- a/packages/opencode/src/mcp/oauth-callback.ts +++ b/packages/opencode/src/mcp/oauth-callback.ts @@ -54,6 +54,9 @@ interface PendingAuth { export namespace McpOAuthCallback { let server: ReturnType | undefined const pendingAuths = new Map() + // Reverse index: mcpName → oauthState, so cancelPending(mcpName) can + // find the right entry in pendingAuths (which is keyed by oauthState). + const mcpNameToState = new Map() const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes @@ -98,6 +101,12 @@ export namespace McpOAuthCallback { const pending = pendingAuths.get(state)! clearTimeout(pending.timeout) pendingAuths.delete(state) + for (const [name, s] of mcpNameToState) { + if (s === state) { + mcpNameToState.delete(name) + break + } + } pending.reject(new Error(errorMsg)) } return new Response(HTML_ERROR(errorMsg), { @@ -126,6 +135,13 @@ export namespace McpOAuthCallback { clearTimeout(pending.timeout) pendingAuths.delete(state) + // Clean up reverse index + for (const [name, s] of mcpNameToState) { + if (s === state) { + mcpNameToState.delete(name) + break + } + } pending.resolve(code) return new Response(HTML_SUCCESS, { @@ -137,11 +153,13 @@ export namespace McpOAuthCallback { log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) } - export function waitForCallback(oauthState: string): Promise { + export function waitForCallback(oauthState: string, mcpName?: string): Promise { + if (mcpName) mcpNameToState.set(mcpName, oauthState) return new Promise((resolve, reject) => { const timeout = setTimeout(() => { if (pendingAuths.has(oauthState)) { pendingAuths.delete(oauthState) + if (mcpName) mcpNameToState.delete(mcpName) reject(new Error("OAuth callback timeout - authorization took too long")) } }, CALLBACK_TIMEOUT_MS) @@ -151,10 +169,14 @@ export namespace McpOAuthCallback { } export function cancelPending(mcpName: string): void { - const pending = pendingAuths.get(mcpName) + // Look up the oauthState for this mcpName via the reverse index + const oauthState = mcpNameToState.get(mcpName) + const key = oauthState ?? mcpName + const pending = pendingAuths.get(key) if (pending) { clearTimeout(pending.timeout) - pendingAuths.delete(mcpName) + pendingAuths.delete(key) + mcpNameToState.delete(mcpName) pending.reject(new Error("Authorization cancelled")) } } @@ -184,6 +206,7 @@ export namespace McpOAuthCallback { pending.reject(new Error("OAuth callback server stopped")) } pendingAuths.clear() + mcpNameToState.clear() } export function isRunning(): boolean { diff --git a/packages/opencode/test/mcp/lifecycle.test.ts b/packages/opencode/test/mcp/lifecycle.test.ts new file mode 100644 index 0000000000..2880c053f1 --- /dev/null +++ b/packages/opencode/test/mcp/lifecycle.test.ts @@ -0,0 +1,660 @@ +import { test, expect, mock, beforeEach } from "bun:test" + +// --- Mock infrastructure --- + +// Per-client state for controlling mock behavior +interface MockClientState { + tools: Array<{ name: string; description?: string; inputSchema: object }> + listToolsCalls: number + listToolsShouldFail: boolean + listToolsError: string + listPromptsShouldFail: boolean + listResourcesShouldFail: boolean + prompts: Array<{ name: string; description?: string }> + resources: Array<{ name: string; uri: string; description?: string }> + closed: boolean + notificationHandlers: Map any> +} + +const clientStates = new Map() +let lastCreatedClientName: string | undefined +let connectShouldFail = false +let connectError = "Mock transport cannot connect" +// Tracks how many Client instances were created (detects leaks) +let clientCreateCount = 0 + +function getOrCreateClientState(name?: string): MockClientState { + const key = name ?? "default" + let state = clientStates.get(key) + if (!state) { + state = { + tools: [{ name: "test_tool", description: "A test tool", inputSchema: { type: "object", properties: {} } }], + listToolsCalls: 0, + listToolsShouldFail: false, + listToolsError: "listTools failed", + listPromptsShouldFail: false, + listResourcesShouldFail: false, + prompts: [], + resources: [], + closed: false, + notificationHandlers: new Map(), + } + clientStates.set(key, state) + } + return state +} + +// Mock transport that succeeds or fails based on connectShouldFail +class MockStdioTransport { + stderr: null = null + pid = 12345 + constructor(_opts: any) {} + async start() { + if (connectShouldFail) throw new Error(connectError) + } + async close() {} +} + +class MockStreamableHTTP { + constructor(_url: URL, _opts?: any) {} + async start() { + if (connectShouldFail) throw new Error(connectError) + } + async close() {} + async finishAuth() {} +} + +class MockSSE { + constructor(_url: URL, _opts?: any) {} + async start() { + throw new Error("SSE fallback - not used in these tests") + } + async close() {} +} + +mock.module("@modelcontextprotocol/sdk/client/stdio.js", () => ({ + StdioClientTransport: MockStdioTransport, +})) + +mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ + StreamableHTTPClientTransport: MockStreamableHTTP, +})) + +mock.module("@modelcontextprotocol/sdk/client/sse.js", () => ({ + SSEClientTransport: MockSSE, +})) + +mock.module("@modelcontextprotocol/sdk/client/auth.js", () => ({ + UnauthorizedError: class extends Error { + constructor() { + super("Unauthorized") + } + }, +})) + +// Mock Client that delegates to per-name MockClientState +mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({ + Client: class MockClient { + _state!: MockClientState + transport: any + + constructor(_opts: any) { + clientCreateCount++ + } + + async connect(transport: { start: () => Promise }) { + this.transport = transport + await transport.start() + // After successful connect, bind to the last-created client name + this._state = getOrCreateClientState(lastCreatedClientName) + } + + setNotificationHandler(schema: unknown, handler: (...args: any[]) => any) { + this._state?.notificationHandlers.set(schema, handler) + } + + async listTools() { + if (this._state) this._state.listToolsCalls++ + if (this._state?.listToolsShouldFail) { + throw new Error(this._state.listToolsError) + } + return { tools: this._state?.tools ?? [] } + } + + async listPrompts() { + if (this._state?.listPromptsShouldFail) { + throw new Error("listPrompts failed") + } + return { prompts: this._state?.prompts ?? [] } + } + + async listResources() { + if (this._state?.listResourcesShouldFail) { + throw new Error("listResources failed") + } + return { resources: this._state?.resources ?? [] } + } + + async close() { + if (this._state) this._state.closed = true + } + }, +})) + +beforeEach(() => { + clientStates.clear() + lastCreatedClientName = undefined + connectShouldFail = false + connectError = "Mock transport cannot connect" + clientCreateCount = 0 +}) + +// Import after mocks +const { MCP } = await import("../../src/mcp/index") +const { Instance } = await import("../../src/project/instance") +const { tmpdir } = await import("../fixture/fixture") + +// --- Helper --- + +function withInstance(config: Record, fn: () => Promise) { + return async () => { + await using tmp = await tmpdir({ + init: async (dir) => { + await Bun.write( + `${dir}/opencode.json`, + JSON.stringify({ + $schema: "https://opencode.ai/config.json", + mcp: config, + }), + ) + }, + }) + + await Instance.provide({ + directory: tmp.path, + fn: async () => { + await fn() + // dispose instance to clean up state between tests + await Instance.dispose() + }, + }) + } +} + +// ======================================================================== +// Test: tools() are cached after connect +// ======================================================================== + +test( + "tools() reuses cached tool definitions after connect", + withInstance({}, async () => { + lastCreatedClientName = "my-server" + const serverState = getOrCreateClientState("my-server") + serverState.tools = [ + { name: "do_thing", description: "does a thing", inputSchema: { type: "object", properties: {} } }, + ] + + // First: add the server successfully + const addResult = await MCP.add("my-server", { + type: "local", + command: ["echo", "test"], + }) + expect((addResult.status as any)["my-server"]?.status ?? (addResult.status as any).status).toBe("connected") + + expect(serverState.listToolsCalls).toBe(1) + + const toolsA = await MCP.tools() + const toolsB = await MCP.tools() + expect(Object.keys(toolsA).length).toBeGreaterThan(0) + expect(Object.keys(toolsB).length).toBeGreaterThan(0) + expect(serverState.listToolsCalls).toBe(1) + }), +) + +// ======================================================================== +// Test: tool change notifications refresh the cache +// ======================================================================== + +test( + "tool change notifications refresh cached tool definitions", + withInstance({}, async () => { + lastCreatedClientName = "status-server" + const serverState = getOrCreateClientState("status-server") + + await MCP.add("status-server", { + type: "local", + command: ["echo", "test"], + }) + + const before = await MCP.tools() + expect(Object.keys(before).some((key) => key.includes("test_tool"))).toBe(true) + expect(serverState.listToolsCalls).toBe(1) + + serverState.tools = [{ name: "next_tool", description: "next", inputSchema: { type: "object", properties: {} } }] + + const handler = Array.from(serverState.notificationHandlers.values())[0] + expect(handler).toBeDefined() + await handler?.() + + const after = await MCP.tools() + expect(Object.keys(after).some((key) => key.includes("next_tool"))).toBe(true) + expect(Object.keys(after).some((key) => key.includes("test_tool"))).toBe(false) + expect(serverState.listToolsCalls).toBe(2) + }), +) + +// ======================================================================== +// Test: connect() / disconnect() lifecycle +// ======================================================================== + +test( + "disconnect sets status to disabled and removes client", + withInstance( + { + "disc-server": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "disc-server" + getOrCreateClientState("disc-server") + + await MCP.add("disc-server", { + type: "local", + command: ["echo", "test"], + }) + + const statusBefore = await MCP.status() + expect(statusBefore["disc-server"]?.status).toBe("connected") + + await MCP.disconnect("disc-server") + + const statusAfter = await MCP.status() + expect(statusAfter["disc-server"]?.status).toBe("disabled") + + // Tools should be empty after disconnect + const tools = await MCP.tools() + const serverTools = Object.keys(tools).filter((k) => k.startsWith("disc-server")) + expect(serverTools.length).toBe(0) + }, + ), +) + +test( + "connect() after disconnect() re-establishes the server", + withInstance( + { + "reconn-server": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "reconn-server" + const serverState = getOrCreateClientState("reconn-server") + serverState.tools = [{ name: "my_tool", description: "a tool", inputSchema: { type: "object", properties: {} } }] + + await MCP.add("reconn-server", { + type: "local", + command: ["echo", "test"], + }) + + await MCP.disconnect("reconn-server") + expect((await MCP.status())["reconn-server"]?.status).toBe("disabled") + + // Reconnect + await MCP.connect("reconn-server") + expect((await MCP.status())["reconn-server"]?.status).toBe("connected") + + const tools = await MCP.tools() + expect(Object.keys(tools).some((k) => k.includes("my_tool"))).toBe(true) + }, + ), +) + +// ======================================================================== +// Test: add() closes existing client before replacing +// ======================================================================== + +test( + "add() closes the old client when replacing a server", + // Don't put the server in config — add it dynamically so we control + // exactly which client instance is "first" vs "second". + withInstance({}, async () => { + lastCreatedClientName = "replace-server" + const firstState = getOrCreateClientState("replace-server") + + await MCP.add("replace-server", { + type: "local", + command: ["echo", "test"], + }) + + expect(firstState.closed).toBe(false) + + // Create new state for second client + clientStates.delete("replace-server") + const secondState = getOrCreateClientState("replace-server") + + // Re-add should close the first client + await MCP.add("replace-server", { + type: "local", + command: ["echo", "test"], + }) + + expect(firstState.closed).toBe(true) + expect(secondState.closed).toBe(false) + }), +) + +// ======================================================================== +// Test: state init with mixed success/failure +// ======================================================================== + +test( + "init connects available servers even when one fails", + withInstance( + { + "good-server": { + type: "local", + command: ["echo", "good"], + }, + "bad-server": { + type: "local", + command: ["echo", "bad"], + }, + }, + async () => { + // Set up good server + const goodState = getOrCreateClientState("good-server") + goodState.tools = [{ name: "good_tool", description: "works", inputSchema: { type: "object", properties: {} } }] + + // Set up bad server - will fail on listTools during create() + const badState = getOrCreateClientState("bad-server") + badState.listToolsShouldFail = true + + // Add good server first + lastCreatedClientName = "good-server" + await MCP.add("good-server", { + type: "local", + command: ["echo", "good"], + }) + + // Add bad server - should fail but not affect good server + lastCreatedClientName = "bad-server" + await MCP.add("bad-server", { + type: "local", + command: ["echo", "bad"], + }) + + const status = await MCP.status() + expect(status["good-server"]?.status).toBe("connected") + expect(status["bad-server"]?.status).toBe("failed") + + // Good server's tools should still be available + const tools = await MCP.tools() + expect(Object.keys(tools).some((k) => k.includes("good_tool"))).toBe(true) + }, + ), +) + +// ======================================================================== +// Test: disabled server via config +// ======================================================================== + +test( + "disabled server is marked as disabled without attempting connection", + withInstance( + { + "disabled-server": { + type: "local", + command: ["echo", "test"], + enabled: false, + }, + }, + async () => { + const countBefore = clientCreateCount + + await MCP.add("disabled-server", { + type: "local", + command: ["echo", "test"], + enabled: false, + } as any) + + // No client should have been created + expect(clientCreateCount).toBe(countBefore) + + const status = await MCP.status() + expect(status["disabled-server"]?.status).toBe("disabled") + }, + ), +) + +// ======================================================================== +// Test: prompts() and resources() +// ======================================================================== + +test( + "prompts() returns prompts from connected servers", + withInstance( + { + "prompt-server": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "prompt-server" + const serverState = getOrCreateClientState("prompt-server") + serverState.prompts = [{ name: "my-prompt", description: "A test prompt" }] + + await MCP.add("prompt-server", { + type: "local", + command: ["echo", "test"], + }) + + const prompts = await MCP.prompts() + expect(Object.keys(prompts).length).toBe(1) + const key = Object.keys(prompts)[0] + expect(key).toContain("prompt-server") + expect(key).toContain("my-prompt") + }, + ), +) + +test( + "resources() returns resources from connected servers", + withInstance( + { + "resource-server": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "resource-server" + const serverState = getOrCreateClientState("resource-server") + serverState.resources = [{ name: "my-resource", uri: "file:///test.txt", description: "A test resource" }] + + await MCP.add("resource-server", { + type: "local", + command: ["echo", "test"], + }) + + const resources = await MCP.resources() + expect(Object.keys(resources).length).toBe(1) + const key = Object.keys(resources)[0] + expect(key).toContain("resource-server") + expect(key).toContain("my-resource") + }, + ), +) + +test( + "prompts() skips disconnected servers", + withInstance( + { + "prompt-disc-server": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "prompt-disc-server" + const serverState = getOrCreateClientState("prompt-disc-server") + serverState.prompts = [{ name: "hidden-prompt", description: "Should not appear" }] + + await MCP.add("prompt-disc-server", { + type: "local", + command: ["echo", "test"], + }) + + await MCP.disconnect("prompt-disc-server") + + const prompts = await MCP.prompts() + expect(Object.keys(prompts).length).toBe(0) + }, + ), +) + +// ======================================================================== +// Test: connect() on nonexistent server +// ======================================================================== + +test( + "connect() on nonexistent server does not throw", + withInstance({}, async () => { + // Should not throw + await MCP.connect("nonexistent") + const status = await MCP.status() + expect(status["nonexistent"]).toBeUndefined() + }), +) + +// ======================================================================== +// Test: disconnect() on nonexistent server +// ======================================================================== + +test( + "disconnect() on nonexistent server does not throw", + withInstance({}, async () => { + await MCP.disconnect("nonexistent") + // Should complete without error + }), +) + +// ======================================================================== +// Test: tools() with no MCP servers configured +// ======================================================================== + +test( + "tools() returns empty when no MCP servers are configured", + withInstance({}, async () => { + const tools = await MCP.tools() + expect(Object.keys(tools).length).toBe(0) + }), +) + +// ======================================================================== +// Test: connect failure during create() +// ======================================================================== + +test( + "server that fails to connect is marked as failed", + withInstance( + { + "fail-connect": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "fail-connect" + getOrCreateClientState("fail-connect") + connectShouldFail = true + connectError = "Connection refused" + + await MCP.add("fail-connect", { + type: "local", + command: ["echo", "test"], + }) + + const status = await MCP.status() + expect(status["fail-connect"]?.status).toBe("failed") + if (status["fail-connect"]?.status === "failed") { + expect(status["fail-connect"].error).toContain("Connection refused") + } + + // No tools should be available + const tools = await MCP.tools() + expect(Object.keys(tools).length).toBe(0) + }, + ), +) + +// ======================================================================== +// Bug #5: McpOAuthCallback.cancelPending uses wrong key +// ======================================================================== + +test("McpOAuthCallback.cancelPending is keyed by mcpName but pendingAuths uses oauthState", async () => { + const { McpOAuthCallback } = await import("../../src/mcp/oauth-callback") + + // Register a pending auth with an oauthState key, associated to an mcpName + const oauthState = "abc123hexstate" + const callbackPromise = McpOAuthCallback.waitForCallback(oauthState, "my-mcp-server") + + // cancelPending is called with mcpName — should find the entry via reverse index + McpOAuthCallback.cancelPending("my-mcp-server") + + // The callback should still be pending because cancelPending looked up + // "my-mcp-server" in a map keyed by "abc123hexstate" + let resolved = false + let rejected = false + callbackPromise.then(() => (resolved = true)).catch(() => (rejected = true)) + + // Give it a tick + await new Promise((r) => setTimeout(r, 50)) + + // cancelPending("my-mcp-server") should have rejected the pending callback + expect(rejected).toBe(true) + + await McpOAuthCallback.stop() +}) + +// ======================================================================== +// Test: multiple tools from same server get correct name prefixes +// ======================================================================== + +test( + "tools() prefixes tool names with sanitized server name", + withInstance( + { + "my.special-server": { + type: "local", + command: ["echo", "test"], + }, + }, + async () => { + lastCreatedClientName = "my.special-server" + const serverState = getOrCreateClientState("my.special-server") + serverState.tools = [ + { name: "tool-a", description: "Tool A", inputSchema: { type: "object", properties: {} } }, + { name: "tool.b", description: "Tool B", inputSchema: { type: "object", properties: {} } }, + ] + + await MCP.add("my.special-server", { + type: "local", + command: ["echo", "test"], + }) + + const tools = await MCP.tools() + const keys = Object.keys(tools) + + // Server name dots should be replaced with underscores + expect(keys.some((k) => k.startsWith("my_special-server_"))).toBe(true) + // Tool name dots should be replaced with underscores + expect(keys.some((k) => k.endsWith("tool_b"))).toBe(true) + expect(keys.length).toBe(2) + }, + ), +)