refactor: replace Bun.serve with Node http.createServer in OAuth handlers (#18327)

Co-authored-by: LukeParkerDev <10430890+Hona@users.noreply.github.com>
pull/18335/head
Dax 2026-04-06 12:17:29 -04:00 committed by GitHub
parent 965c751522
commit 2e4c43c1cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 147 additions and 139 deletions

View File

@ -1,4 +1,5 @@
import { createConnection } from "net" import { createConnection } from "net"
import { createServer } from "http"
import { Log } from "../util/log" import { Log } from "../util/log"
import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider" import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
@ -52,7 +53,7 @@ interface PendingAuth {
} }
export namespace McpOAuthCallback { export namespace McpOAuthCallback {
let server: ReturnType<typeof Bun.serve> | undefined let server: ReturnType<typeof createServer> | undefined
const pendingAuths = new Map<string, PendingAuth>() const pendingAuths = new Map<string, PendingAuth>()
// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can // Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
// find the right entry in pendingAuths (which is keyed by oauthState). // find the right entry in pendingAuths (which is keyed by oauthState).
@ -60,22 +61,22 @@ export namespace McpOAuthCallback {
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
export async function ensureRunning(): Promise<void> { function cleanupStateIndex(oauthState: string) {
if (server) return for (const [name, state] of mcpNameToState) {
if (state === oauthState) {
const running = await isPortInUse() mcpNameToState.delete(name)
if (running) { break
log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT }) }
return }
} }
server = Bun.serve({ function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
port: OAUTH_CALLBACK_PORT, const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`)
fetch(req) {
const url = new URL(req.url)
if (url.pathname !== OAUTH_CALLBACK_PATH) { if (url.pathname !== OAUTH_CALLBACK_PATH) {
return new Response("Not found", { status: 404 }) res.writeHead(404)
res.end("Not found")
return
} }
const code = url.searchParams.get("code") const code = url.searchParams.get("code")
@ -89,10 +90,9 @@ export namespace McpOAuthCallback {
if (!state) { if (!state) {
const errorMsg = "Missing required state parameter - potential CSRF attack" const errorMsg = "Missing required state parameter - potential CSRF attack"
log.error("oauth callback missing state parameter", { url: url.toString() }) log.error("oauth callback missing state parameter", { url: url.toString() })
return new Response(HTML_ERROR(errorMsg), { res.writeHead(400, { "Content-Type": "text/html" })
status: 400, res.end(HTML_ERROR(errorMsg))
headers: { "Content-Type": "text/html" }, return
})
} }
if (error) { if (error) {
@ -101,56 +101,57 @@ export namespace McpOAuthCallback {
const pending = pendingAuths.get(state)! const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout) clearTimeout(pending.timeout)
pendingAuths.delete(state) pendingAuths.delete(state)
for (const [name, s] of mcpNameToState) { cleanupStateIndex(state)
if (s === state) {
mcpNameToState.delete(name)
break
}
}
pending.reject(new Error(errorMsg)) pending.reject(new Error(errorMsg))
} }
return new Response(HTML_ERROR(errorMsg), { res.writeHead(200, { "Content-Type": "text/html" })
headers: { "Content-Type": "text/html" }, res.end(HTML_ERROR(errorMsg))
}) return
} }
if (!code) { if (!code) {
return new Response(HTML_ERROR("No authorization code provided"), { res.writeHead(400, { "Content-Type": "text/html" })
status: 400, res.end(HTML_ERROR("No authorization code provided"))
headers: { "Content-Type": "text/html" }, return
})
} }
// Validate state parameter // Validate state parameter
if (!pendingAuths.has(state)) { if (!pendingAuths.has(state)) {
const errorMsg = "Invalid or expired state parameter - potential CSRF attack" const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
return new Response(HTML_ERROR(errorMsg), { res.writeHead(400, { "Content-Type": "text/html" })
status: 400, res.end(HTML_ERROR(errorMsg))
headers: { "Content-Type": "text/html" }, return
})
} }
const pending = pendingAuths.get(state)! const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout) clearTimeout(pending.timeout)
pendingAuths.delete(state) pendingAuths.delete(state)
// Clean up reverse index cleanupStateIndex(state)
for (const [name, s] of mcpNameToState) {
if (s === state) {
mcpNameToState.delete(name)
break
}
}
pending.resolve(code) pending.resolve(code)
return new Response(HTML_SUCCESS, { res.writeHead(200, { "Content-Type": "text/html" })
headers: { "Content-Type": "text/html" }, res.end(HTML_SUCCESS)
}) }
},
})
export async function ensureRunning(): Promise<void> {
if (server) return
const running = await isPortInUse()
if (running) {
log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
return
}
server = createServer(handleRequest)
await new Promise<void>((resolve, reject) => {
server!.listen(OAUTH_CALLBACK_PORT, () => {
log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT }) log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
resolve()
})
server!.on("error", reject)
})
} }
export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> { export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
@ -196,7 +197,7 @@ export namespace McpOAuthCallback {
export async function stop(): Promise<void> { export async function stop(): Promise<void> {
if (server) { if (server) {
server.stop() await new Promise<void>((resolve) => server!.close(() => resolve()))
server = undefined server = undefined
log.info("oauth callback server stopped") log.info("oauth callback server stopped")
} }

View File

@ -6,6 +6,7 @@ import os from "os"
import { ProviderTransform } from "@/provider/transform" import { ProviderTransform } from "@/provider/transform"
import { ModelID, ProviderID } from "@/provider/schema" import { ModelID, ProviderID } from "@/provider/schema"
import { setTimeout as sleep } from "node:timers/promises" import { setTimeout as sleep } from "node:timers/promises"
import { createServer } from "http"
const log = Log.create({ service: "plugin.codex" }) const log = Log.create({ service: "plugin.codex" })
@ -241,7 +242,7 @@ interface PendingOAuth {
reject: (error: Error) => void reject: (error: Error) => void
} }
let oauthServer: ReturnType<typeof Bun.serve> | undefined let oauthServer: ReturnType<typeof createServer> | undefined
let pendingOAuth: PendingOAuth | undefined let pendingOAuth: PendingOAuth | undefined
async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> { async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> {
@ -249,10 +250,8 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` } return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
} }
oauthServer = Bun.serve({ oauthServer = createServer((req, res) => {
port: OAUTH_PORT, const url = new URL(req.url || "/", `http://localhost:${OAUTH_PORT}`)
fetch(req) {
const url = new URL(req.url)
if (url.pathname === "/auth/callback") { if (url.pathname === "/auth/callback") {
const code = url.searchParams.get("code") const code = url.searchParams.get("code")
@ -264,29 +263,27 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
const errorMsg = errorDescription || error const errorMsg = errorDescription || error
pendingOAuth?.reject(new Error(errorMsg)) pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined pendingOAuth = undefined
return new Response(HTML_ERROR(errorMsg), { res.writeHead(200, { "Content-Type": "text/html" })
headers: { "Content-Type": "text/html" }, res.end(HTML_ERROR(errorMsg))
}) return
} }
if (!code) { if (!code) {
const errorMsg = "Missing authorization code" const errorMsg = "Missing authorization code"
pendingOAuth?.reject(new Error(errorMsg)) pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined pendingOAuth = undefined
return new Response(HTML_ERROR(errorMsg), { res.writeHead(400, { "Content-Type": "text/html" })
status: 400, res.end(HTML_ERROR(errorMsg))
headers: { "Content-Type": "text/html" }, return
})
} }
if (!pendingOAuth || state !== pendingOAuth.state) { if (!pendingOAuth || state !== pendingOAuth.state) {
const errorMsg = "Invalid state - potential CSRF attack" const errorMsg = "Invalid state - potential CSRF attack"
pendingOAuth?.reject(new Error(errorMsg)) pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined pendingOAuth = undefined
return new Response(HTML_ERROR(errorMsg), { res.writeHead(400, { "Content-Type": "text/html" })
status: 400, res.end(HTML_ERROR(errorMsg))
headers: { "Content-Type": "text/html" }, return
})
} }
const current = pendingOAuth const current = pendingOAuth
@ -296,30 +293,40 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
.then((tokens) => current.resolve(tokens)) .then((tokens) => current.resolve(tokens))
.catch((err) => current.reject(err)) .catch((err) => current.reject(err))
return new Response(HTML_SUCCESS, { res.writeHead(200, { "Content-Type": "text/html" })
headers: { "Content-Type": "text/html" }, res.end(HTML_SUCCESS)
}) return
} }
if (url.pathname === "/cancel") { if (url.pathname === "/cancel") {
pendingOAuth?.reject(new Error("Login cancelled")) pendingOAuth?.reject(new Error("Login cancelled"))
pendingOAuth = undefined pendingOAuth = undefined
return new Response("Login cancelled", { status: 200 }) res.writeHead(200)
res.end("Login cancelled")
return
} }
return new Response("Not found", { status: 404 }) res.writeHead(404)
}, res.end("Not found")
}) })
await new Promise<void>((resolve, reject) => {
oauthServer!.listen(OAUTH_PORT, () => {
log.info("codex oauth server started", { port: OAUTH_PORT }) log.info("codex oauth server started", { port: OAUTH_PORT })
resolve()
})
oauthServer!.on("error", reject)
})
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` } return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
} }
function stopOAuthServer() { function stopOAuthServer() {
if (oauthServer) { if (oauthServer) {
oauthServer.stop() oauthServer.close(() => {
oauthServer = undefined
log.info("codex oauth server stopped") log.info("codex oauth server stopped")
})
oauthServer = undefined
} }
} }