refactor: replace Bun.serve with Node http.createServer in OAuth handlers (#18327)
Co-authored-by: LukeParkerDev <10430890+Hona@users.noreply.github.com>pull/18335/head
parent
965c751522
commit
2e4c43c1cf
|
|
@ -1,4 +1,5 @@
|
||||||
import { createConnection } from "net"
|
import { createConnection } from "net"
|
||||||
|
import { createServer } from "http"
|
||||||
import { Log } from "../util/log"
|
import { Log } from "../util/log"
|
||||||
import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
|
import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
|
||||||
|
|
||||||
|
|
@ -52,7 +53,7 @@ interface PendingAuth {
|
||||||
}
|
}
|
||||||
|
|
||||||
export namespace McpOAuthCallback {
|
export namespace McpOAuthCallback {
|
||||||
let server: ReturnType<typeof Bun.serve> | undefined
|
let server: ReturnType<typeof createServer> | undefined
|
||||||
const pendingAuths = new Map<string, PendingAuth>()
|
const pendingAuths = new Map<string, PendingAuth>()
|
||||||
// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
|
// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
|
||||||
// find the right entry in pendingAuths (which is keyed by oauthState).
|
// find the right entry in pendingAuths (which is keyed by oauthState).
|
||||||
|
|
@ -60,22 +61,22 @@ export namespace McpOAuthCallback {
|
||||||
|
|
||||||
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
|
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
|
||||||
|
|
||||||
export async function ensureRunning(): Promise<void> {
|
function cleanupStateIndex(oauthState: string) {
|
||||||
if (server) return
|
for (const [name, state] of mcpNameToState) {
|
||||||
|
if (state === oauthState) {
|
||||||
const running = await isPortInUse()
|
mcpNameToState.delete(name)
|
||||||
if (running) {
|
break
|
||||||
log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
|
}
|
||||||
return
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server = Bun.serve({
|
function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
|
||||||
port: OAUTH_CALLBACK_PORT,
|
const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`)
|
||||||
fetch(req) {
|
|
||||||
const url = new URL(req.url)
|
|
||||||
|
|
||||||
if (url.pathname !== OAUTH_CALLBACK_PATH) {
|
if (url.pathname !== OAUTH_CALLBACK_PATH) {
|
||||||
return new Response("Not found", { status: 404 })
|
res.writeHead(404)
|
||||||
|
res.end("Not found")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const code = url.searchParams.get("code")
|
const code = url.searchParams.get("code")
|
||||||
|
|
@ -89,10 +90,9 @@ export namespace McpOAuthCallback {
|
||||||
if (!state) {
|
if (!state) {
|
||||||
const errorMsg = "Missing required state parameter - potential CSRF attack"
|
const errorMsg = "Missing required state parameter - potential CSRF attack"
|
||||||
log.error("oauth callback missing state parameter", { url: url.toString() })
|
log.error("oauth callback missing state parameter", { url: url.toString() })
|
||||||
return new Response(HTML_ERROR(errorMsg), {
|
res.writeHead(400, { "Content-Type": "text/html" })
|
||||||
status: 400,
|
res.end(HTML_ERROR(errorMsg))
|
||||||
headers: { "Content-Type": "text/html" },
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
|
|
@ -101,56 +101,57 @@ export namespace McpOAuthCallback {
|
||||||
const pending = pendingAuths.get(state)!
|
const pending = pendingAuths.get(state)!
|
||||||
clearTimeout(pending.timeout)
|
clearTimeout(pending.timeout)
|
||||||
pendingAuths.delete(state)
|
pendingAuths.delete(state)
|
||||||
for (const [name, s] of mcpNameToState) {
|
cleanupStateIndex(state)
|
||||||
if (s === state) {
|
|
||||||
mcpNameToState.delete(name)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pending.reject(new Error(errorMsg))
|
pending.reject(new Error(errorMsg))
|
||||||
}
|
}
|
||||||
return new Response(HTML_ERROR(errorMsg), {
|
res.writeHead(200, { "Content-Type": "text/html" })
|
||||||
headers: { "Content-Type": "text/html" },
|
res.end(HTML_ERROR(errorMsg))
|
||||||
})
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!code) {
|
if (!code) {
|
||||||
return new Response(HTML_ERROR("No authorization code provided"), {
|
res.writeHead(400, { "Content-Type": "text/html" })
|
||||||
status: 400,
|
res.end(HTML_ERROR("No authorization code provided"))
|
||||||
headers: { "Content-Type": "text/html" },
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate state parameter
|
// Validate state parameter
|
||||||
if (!pendingAuths.has(state)) {
|
if (!pendingAuths.has(state)) {
|
||||||
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
|
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
|
||||||
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
|
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
|
||||||
return new Response(HTML_ERROR(errorMsg), {
|
res.writeHead(400, { "Content-Type": "text/html" })
|
||||||
status: 400,
|
res.end(HTML_ERROR(errorMsg))
|
||||||
headers: { "Content-Type": "text/html" },
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const pending = pendingAuths.get(state)!
|
const pending = pendingAuths.get(state)!
|
||||||
|
|
||||||
clearTimeout(pending.timeout)
|
clearTimeout(pending.timeout)
|
||||||
pendingAuths.delete(state)
|
pendingAuths.delete(state)
|
||||||
// Clean up reverse index
|
cleanupStateIndex(state)
|
||||||
for (const [name, s] of mcpNameToState) {
|
|
||||||
if (s === state) {
|
|
||||||
mcpNameToState.delete(name)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pending.resolve(code)
|
pending.resolve(code)
|
||||||
|
|
||||||
return new Response(HTML_SUCCESS, {
|
res.writeHead(200, { "Content-Type": "text/html" })
|
||||||
headers: { "Content-Type": "text/html" },
|
res.end(HTML_SUCCESS)
|
||||||
})
|
}
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
export async function ensureRunning(): Promise<void> {
|
||||||
|
if (server) return
|
||||||
|
|
||||||
|
const running = await isPortInUse()
|
||||||
|
if (running) {
|
||||||
|
log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
server = createServer(handleRequest)
|
||||||
|
await new Promise<void>((resolve, reject) => {
|
||||||
|
server!.listen(OAUTH_CALLBACK_PORT, () => {
|
||||||
log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
|
log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
|
||||||
|
resolve()
|
||||||
|
})
|
||||||
|
server!.on("error", reject)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
|
export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
|
||||||
|
|
@ -196,7 +197,7 @@ export namespace McpOAuthCallback {
|
||||||
|
|
||||||
export async function stop(): Promise<void> {
|
export async function stop(): Promise<void> {
|
||||||
if (server) {
|
if (server) {
|
||||||
server.stop()
|
await new Promise<void>((resolve) => server!.close(() => resolve()))
|
||||||
server = undefined
|
server = undefined
|
||||||
log.info("oauth callback server stopped")
|
log.info("oauth callback server stopped")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import os from "os"
|
||||||
import { ProviderTransform } from "@/provider/transform"
|
import { ProviderTransform } from "@/provider/transform"
|
||||||
import { ModelID, ProviderID } from "@/provider/schema"
|
import { ModelID, ProviderID } from "@/provider/schema"
|
||||||
import { setTimeout as sleep } from "node:timers/promises"
|
import { setTimeout as sleep } from "node:timers/promises"
|
||||||
|
import { createServer } from "http"
|
||||||
|
|
||||||
const log = Log.create({ service: "plugin.codex" })
|
const log = Log.create({ service: "plugin.codex" })
|
||||||
|
|
||||||
|
|
@ -241,7 +242,7 @@ interface PendingOAuth {
|
||||||
reject: (error: Error) => void
|
reject: (error: Error) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
let oauthServer: ReturnType<typeof Bun.serve> | undefined
|
let oauthServer: ReturnType<typeof createServer> | undefined
|
||||||
let pendingOAuth: PendingOAuth | undefined
|
let pendingOAuth: PendingOAuth | undefined
|
||||||
|
|
||||||
async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> {
|
async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> {
|
||||||
|
|
@ -249,10 +250,8 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
|
||||||
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
|
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthServer = Bun.serve({
|
oauthServer = createServer((req, res) => {
|
||||||
port: OAUTH_PORT,
|
const url = new URL(req.url || "/", `http://localhost:${OAUTH_PORT}`)
|
||||||
fetch(req) {
|
|
||||||
const url = new URL(req.url)
|
|
||||||
|
|
||||||
if (url.pathname === "/auth/callback") {
|
if (url.pathname === "/auth/callback") {
|
||||||
const code = url.searchParams.get("code")
|
const code = url.searchParams.get("code")
|
||||||
|
|
@ -264,29 +263,27 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
|
||||||
const errorMsg = errorDescription || error
|
const errorMsg = errorDescription || error
|
||||||
pendingOAuth?.reject(new Error(errorMsg))
|
pendingOAuth?.reject(new Error(errorMsg))
|
||||||
pendingOAuth = undefined
|
pendingOAuth = undefined
|
||||||
return new Response(HTML_ERROR(errorMsg), {
|
res.writeHead(200, { "Content-Type": "text/html" })
|
||||||
headers: { "Content-Type": "text/html" },
|
res.end(HTML_ERROR(errorMsg))
|
||||||
})
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!code) {
|
if (!code) {
|
||||||
const errorMsg = "Missing authorization code"
|
const errorMsg = "Missing authorization code"
|
||||||
pendingOAuth?.reject(new Error(errorMsg))
|
pendingOAuth?.reject(new Error(errorMsg))
|
||||||
pendingOAuth = undefined
|
pendingOAuth = undefined
|
||||||
return new Response(HTML_ERROR(errorMsg), {
|
res.writeHead(400, { "Content-Type": "text/html" })
|
||||||
status: 400,
|
res.end(HTML_ERROR(errorMsg))
|
||||||
headers: { "Content-Type": "text/html" },
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!pendingOAuth || state !== pendingOAuth.state) {
|
if (!pendingOAuth || state !== pendingOAuth.state) {
|
||||||
const errorMsg = "Invalid state - potential CSRF attack"
|
const errorMsg = "Invalid state - potential CSRF attack"
|
||||||
pendingOAuth?.reject(new Error(errorMsg))
|
pendingOAuth?.reject(new Error(errorMsg))
|
||||||
pendingOAuth = undefined
|
pendingOAuth = undefined
|
||||||
return new Response(HTML_ERROR(errorMsg), {
|
res.writeHead(400, { "Content-Type": "text/html" })
|
||||||
status: 400,
|
res.end(HTML_ERROR(errorMsg))
|
||||||
headers: { "Content-Type": "text/html" },
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const current = pendingOAuth
|
const current = pendingOAuth
|
||||||
|
|
@ -296,30 +293,40 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
|
||||||
.then((tokens) => current.resolve(tokens))
|
.then((tokens) => current.resolve(tokens))
|
||||||
.catch((err) => current.reject(err))
|
.catch((err) => current.reject(err))
|
||||||
|
|
||||||
return new Response(HTML_SUCCESS, {
|
res.writeHead(200, { "Content-Type": "text/html" })
|
||||||
headers: { "Content-Type": "text/html" },
|
res.end(HTML_SUCCESS)
|
||||||
})
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (url.pathname === "/cancel") {
|
if (url.pathname === "/cancel") {
|
||||||
pendingOAuth?.reject(new Error("Login cancelled"))
|
pendingOAuth?.reject(new Error("Login cancelled"))
|
||||||
pendingOAuth = undefined
|
pendingOAuth = undefined
|
||||||
return new Response("Login cancelled", { status: 200 })
|
res.writeHead(200)
|
||||||
|
res.end("Login cancelled")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Response("Not found", { status: 404 })
|
res.writeHead(404)
|
||||||
},
|
res.end("Not found")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
await new Promise<void>((resolve, reject) => {
|
||||||
|
oauthServer!.listen(OAUTH_PORT, () => {
|
||||||
log.info("codex oauth server started", { port: OAUTH_PORT })
|
log.info("codex oauth server started", { port: OAUTH_PORT })
|
||||||
|
resolve()
|
||||||
|
})
|
||||||
|
oauthServer!.on("error", reject)
|
||||||
|
})
|
||||||
|
|
||||||
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
|
return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
|
||||||
}
|
}
|
||||||
|
|
||||||
function stopOAuthServer() {
|
function stopOAuthServer() {
|
||||||
if (oauthServer) {
|
if (oauthServer) {
|
||||||
oauthServer.stop()
|
oauthServer.close(() => {
|
||||||
oauthServer = undefined
|
|
||||||
log.info("codex oauth server stopped")
|
log.info("codex oauth server stopped")
|
||||||
|
})
|
||||||
|
oauthServer = undefined
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue