chore: extract OAuth changes into #18327

pull/16918/head
Dax Raad 2026-03-19 21:48:23 -04:00
parent cbc40a5981
commit 65e786258a
2 changed files with 126 additions and 136 deletions

View File

@ -1,5 +1,4 @@
import { createConnection } from "net"
import { createServer } from "http"
import { Log } from "../util/log"
import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
@ -53,74 +52,11 @@ interface PendingAuth {
}
export namespace McpOAuthCallback {
let server: ReturnType<typeof createServer> | undefined
let server: ReturnType<typeof Bun.serve> | undefined
const pendingAuths = new Map<string, PendingAuth>()
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`)
if (url.pathname !== OAUTH_CALLBACK_PATH) {
res.writeHead(404)
res.end("Not found")
return
}
const code = url.searchParams.get("code")
const state = url.searchParams.get("state")
const error = url.searchParams.get("error")
const errorDescription = url.searchParams.get("error_description")
log.info("received oauth callback", { hasCode: !!code, state, error })
// Enforce state parameter presence
if (!state) {
const errorMsg = "Missing required state parameter - potential CSRF attack"
log.error("oauth callback missing state parameter", { url: url.toString() })
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}
if (error) {
const errorMsg = errorDescription || error
if (pendingAuths.has(state)) {
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
pending.reject(new Error(errorMsg))
}
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}
if (!code) {
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR("No authorization code provided"))
return
}
// Validate state parameter
if (!pendingAuths.has(state)) {
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
pending.resolve(code)
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_SUCCESS)
}
export async function ensureRunning(): Promise<void> {
if (server) return
@ -130,14 +66,75 @@ export namespace McpOAuthCallback {
return
}
server = createServer(handleRequest)
await new Promise<void>((resolve, reject) => {
server!.listen(OAUTH_CALLBACK_PORT, () => {
log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
resolve()
})
server!.on("error", reject)
server = Bun.serve({
port: OAUTH_CALLBACK_PORT,
fetch(req) {
const url = new URL(req.url)
if (url.pathname !== OAUTH_CALLBACK_PATH) {
return new Response("Not found", { status: 404 })
}
const code = url.searchParams.get("code")
const state = url.searchParams.get("state")
const error = url.searchParams.get("error")
const errorDescription = url.searchParams.get("error_description")
log.info("received oauth callback", { hasCode: !!code, state, error })
// Enforce state parameter presence
if (!state) {
const errorMsg = "Missing required state parameter - potential CSRF attack"
log.error("oauth callback missing state parameter", { url: url.toString() })
return new Response(HTML_ERROR(errorMsg), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}
if (error) {
const errorMsg = errorDescription || error
if (pendingAuths.has(state)) {
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
pending.reject(new Error(errorMsg))
}
return new Response(HTML_ERROR(errorMsg), {
headers: { "Content-Type": "text/html" },
})
}
if (!code) {
return new Response(HTML_ERROR("No authorization code provided"), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}
// Validate state parameter
if (!pendingAuths.has(state)) {
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
return new Response(HTML_ERROR(errorMsg), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
pending.resolve(code)
return new Response(HTML_SUCCESS, {
headers: { "Content-Type": "text/html" },
})
},
})
log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
}
export function waitForCallback(oauthState: string): Promise<string> {
@ -177,7 +174,7 @@ export namespace McpOAuthCallback {
export async function stop(): Promise<void> {
if (server) {
await new Promise<void>((resolve) => server!.close(() => resolve()))
server.stop()
server = undefined
log.info("oauth callback server stopped")
}

View File

@ -6,7 +6,6 @@ import os from "os"
import { ProviderTransform } from "@/provider/transform"
import { ModelID, ProviderID } from "@/provider/schema"
import { setTimeout as sleep } from "node:timers/promises"
import { createServer } from "http"
const log = Log.create({ service: "plugin.codex" })
@ -242,7 +241,7 @@ interface PendingOAuth {
reject: (error: Error) => void
}
let oauthServer: ReturnType<typeof createServer> | undefined
let oauthServer: ReturnType<typeof Bun.serve> | undefined
let pendingOAuth: PendingOAuth | undefined
async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> {
@ -250,83 +249,77 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
}
oauthServer = createServer((req, res) => {
const url = new URL(req.url || "/", `http://localhost:${OAUTH_PORT}`)
oauthServer = Bun.serve({
port: OAUTH_PORT,
fetch(req) {
const url = new URL(req.url)
if (url.pathname === "/auth/callback") {
const code = url.searchParams.get("code")
const state = url.searchParams.get("state")
const error = url.searchParams.get("error")
const errorDescription = url.searchParams.get("error_description")
if (url.pathname === "/auth/callback") {
const code = url.searchParams.get("code")
const state = url.searchParams.get("state")
const error = url.searchParams.get("error")
const errorDescription = url.searchParams.get("error_description")
if (error) {
const errorMsg = errorDescription || error
pendingOAuth?.reject(new Error(errorMsg))
if (error) {
const errorMsg = errorDescription || error
pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined
return new Response(HTML_ERROR(errorMsg), {
headers: { "Content-Type": "text/html" },
})
}
if (!code) {
const errorMsg = "Missing authorization code"
pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined
return new Response(HTML_ERROR(errorMsg), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}
if (!pendingOAuth || state !== pendingOAuth.state) {
const errorMsg = "Invalid state - potential CSRF attack"
pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined
return new Response(HTML_ERROR(errorMsg), {
status: 400,
headers: { "Content-Type": "text/html" },
})
}
const current = pendingOAuth
pendingOAuth = undefined
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
exchangeCodeForTokens(code, `http://localhost:${OAUTH_PORT}/auth/callback`, current.pkce)
.then((tokens) => current.resolve(tokens))
.catch((err) => current.reject(err))
return new Response(HTML_SUCCESS, {
headers: { "Content-Type": "text/html" },
})
}
if (!code) {
const errorMsg = "Missing authorization code"
pendingOAuth?.reject(new Error(errorMsg))
if (url.pathname === "/cancel") {
pendingOAuth?.reject(new Error("Login cancelled"))
pendingOAuth = undefined
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
return new Response("Login cancelled", { status: 200 })
}
if (!pendingOAuth || state !== pendingOAuth.state) {
const errorMsg = "Invalid state - potential CSRF attack"
pendingOAuth?.reject(new Error(errorMsg))
pendingOAuth = undefined
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}
const current = pendingOAuth
pendingOAuth = undefined
exchangeCodeForTokens(code, `http://localhost:${OAUTH_PORT}/auth/callback`, current.pkce)
.then((tokens) => current.resolve(tokens))
.catch((err) => current.reject(err))
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_SUCCESS)
return
}
if (url.pathname === "/cancel") {
pendingOAuth?.reject(new Error("Login cancelled"))
pendingOAuth = undefined
res.writeHead(200)
res.end("Login cancelled")
return
}
res.writeHead(404)
res.end("Not found")
})
await new Promise<void>((resolve, reject) => {
oauthServer!.listen(OAUTH_PORT, () => {
log.info("codex oauth server started", { port: OAUTH_PORT })
resolve()
})
oauthServer!.on("error", reject)
return new Response("Not found", { status: 404 })
},
})
log.info("codex oauth server started", { port: OAUTH_PORT })
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
}
function stopOAuthServer() {
if (oauthServer) {
oauthServer.close(() => {
log.info("codex oauth server stopped")
})
oauthServer.stop()
oauthServer = undefined
log.info("codex oauth server stopped")
}
}