Compare commits
4 Commits
dev
...
fix-read-t
| Author | SHA1 | Date |
|---|---|---|
|
|
c92351ec52 | |
|
|
4bd25e0e33 | |
|
|
798f866d4c | |
|
|
838c9968ec |
|
|
@ -1,839 +0,0 @@
|
||||||
import type { APIEvent } from "@solidjs/start/server"
|
|
||||||
import { and, Database, eq, isNull, lt, or, sql } from "@opencode-ai/console-core/drizzle/index.js"
|
|
||||||
import { KeyTable } from "@opencode-ai/console-core/schema/key.sql.js"
|
|
||||||
import { BillingTable, SubscriptionTable, UsageTable } from "@opencode-ai/console-core/schema/billing.sql.js"
|
|
||||||
import { centsToMicroCents } from "@opencode-ai/console-core/util/price.js"
|
|
||||||
import { getWeekBounds } from "@opencode-ai/console-core/util/date.js"
|
|
||||||
import { Identifier } from "@opencode-ai/console-core/identifier.js"
|
|
||||||
import { Billing } from "@opencode-ai/console-core/billing.js"
|
|
||||||
import { Actor } from "@opencode-ai/console-core/actor.js"
|
|
||||||
import { WorkspaceTable } from "@opencode-ai/console-core/schema/workspace.sql.js"
|
|
||||||
import { ZenData } from "@opencode-ai/console-core/model.js"
|
|
||||||
import { Black, BlackData } from "@opencode-ai/console-core/black.js"
|
|
||||||
import { UserTable } from "@opencode-ai/console-core/schema/user.sql.js"
|
|
||||||
import { ModelTable } from "@opencode-ai/console-core/schema/model.sql.js"
|
|
||||||
import { ProviderTable } from "@opencode-ai/console-core/schema/provider.sql.js"
|
|
||||||
import { logger } from "./logger"
|
|
||||||
import {
|
|
||||||
AuthError,
|
|
||||||
CreditsError,
|
|
||||||
MonthlyLimitError,
|
|
||||||
UserLimitError,
|
|
||||||
ModelError,
|
|
||||||
FreeUsageLimitError,
|
|
||||||
SubscriptionUsageLimitError,
|
|
||||||
} from "./error"
|
|
||||||
import { createBodyConverter, createStreamPartConverter, createResponseConverter, UsageInfo } from "./provider/provider"
|
|
||||||
import { anthropicHelper } from "./provider/anthropic"
|
|
||||||
import { googleHelper } from "./provider/google"
|
|
||||||
import { openaiHelper } from "./provider/openai"
|
|
||||||
import { oaCompatHelper } from "./provider/openai-compatible"
|
|
||||||
import { createRateLimiter } from "./rateLimiter"
|
|
||||||
import { createDataDumper } from "./dataDumper"
|
|
||||||
import { createTrialLimiter } from "./trialLimiter"
|
|
||||||
import { createStickyTracker } from "./stickyProviderTracker"
|
|
||||||
|
|
||||||
type ZenData = Awaited<ReturnType<typeof ZenData.list>>
|
|
||||||
type RetryOptions = {
|
|
||||||
excludeProviders: string[]
|
|
||||||
retryCount: number
|
|
||||||
}
|
|
||||||
type BillingSource = "anonymous" | "free" | "byok" | "subscription" | "balance"
|
|
||||||
|
|
||||||
export async function handler(
|
|
||||||
input: APIEvent,
|
|
||||||
opts: {
|
|
||||||
format: ZenData.Format
|
|
||||||
parseApiKey: (headers: Headers) => string | undefined
|
|
||||||
parseModel: (url: string, body: any) => string
|
|
||||||
parseIsStream: (url: string, body: any) => boolean
|
|
||||||
},
|
|
||||||
) {
|
|
||||||
type AuthInfo = Awaited<ReturnType<typeof authenticate>>
|
|
||||||
type ModelInfo = Awaited<ReturnType<typeof validateModel>>
|
|
||||||
type ProviderInfo = Awaited<ReturnType<typeof selectProvider>>
|
|
||||||
type CostInfo = ReturnType<typeof calculateCost>
|
|
||||||
|
|
||||||
const MAX_FAILOVER_RETRIES = 3
|
|
||||||
const MAX_429_RETRIES = 3
|
|
||||||
const FREE_WORKSPACES = [
|
|
||||||
"wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
|
|
||||||
"wrk_01K6W1A3VE0KMNVSCQT43BG2SX", // opencode bench
|
|
||||||
]
|
|
||||||
|
|
||||||
try {
|
|
||||||
const url = input.request.url
|
|
||||||
const body = await input.request.json()
|
|
||||||
const model = opts.parseModel(url, body)
|
|
||||||
const isStream = opts.parseIsStream(url, body)
|
|
||||||
const ip = input.request.headers.get("x-real-ip") ?? ""
|
|
||||||
const sessionId = input.request.headers.get("x-opencode-session") ?? ""
|
|
||||||
const requestId = input.request.headers.get("x-opencode-request") ?? ""
|
|
||||||
const projectId = input.request.headers.get("x-opencode-project") ?? ""
|
|
||||||
const ocClient = input.request.headers.get("x-opencode-client") ?? ""
|
|
||||||
logger.metric({
|
|
||||||
is_tream: isStream,
|
|
||||||
session: sessionId,
|
|
||||||
request: requestId,
|
|
||||||
client: ocClient,
|
|
||||||
})
|
|
||||||
const zenData = ZenData.list()
|
|
||||||
const modelInfo = validateModel(zenData, model)
|
|
||||||
const dataDumper = createDataDumper(sessionId, requestId, projectId)
|
|
||||||
const trialLimiter = createTrialLimiter(modelInfo.trial, ip, ocClient)
|
|
||||||
const isTrial = await trialLimiter?.isTrial()
|
|
||||||
const rateLimiter = createRateLimiter(modelInfo.rateLimit, ip, input.request.headers)
|
|
||||||
await rateLimiter?.check()
|
|
||||||
const stickyTracker = createStickyTracker(modelInfo.stickyProvider, sessionId)
|
|
||||||
const stickyProvider = await stickyTracker?.get()
|
|
||||||
const authInfo = await authenticate(modelInfo)
|
|
||||||
const billingSource = validateBilling(authInfo, modelInfo)
|
|
||||||
|
|
||||||
const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
|
|
||||||
const providerInfo = selectProvider(
|
|
||||||
model,
|
|
||||||
zenData,
|
|
||||||
authInfo,
|
|
||||||
modelInfo,
|
|
||||||
sessionId,
|
|
||||||
isTrial ?? false,
|
|
||||||
retry,
|
|
||||||
stickyProvider,
|
|
||||||
)
|
|
||||||
validateModelSettings(authInfo)
|
|
||||||
updateProviderKey(authInfo, providerInfo)
|
|
||||||
logger.metric({ provider: providerInfo.id })
|
|
||||||
|
|
||||||
const startTimestamp = Date.now()
|
|
||||||
const reqUrl = providerInfo.modifyUrl(providerInfo.api, isStream)
|
|
||||||
const reqBody = JSON.stringify(
|
|
||||||
providerInfo.modifyBody({
|
|
||||||
...createBodyConverter(opts.format, providerInfo.format)(body),
|
|
||||||
model: providerInfo.model,
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
logger.debug("REQUEST URL: " + reqUrl)
|
|
||||||
logger.debug("REQUEST: " + reqBody.substring(0, 300) + "...")
|
|
||||||
const res = await fetchWith429Retry(reqUrl, {
|
|
||||||
method: "POST",
|
|
||||||
headers: (() => {
|
|
||||||
const headers = new Headers(input.request.headers)
|
|
||||||
providerInfo.modifyHeaders(headers, body, providerInfo.apiKey)
|
|
||||||
Object.entries(providerInfo.headerMappings ?? {}).forEach(([k, v]) => {
|
|
||||||
headers.set(k, headers.get(v)!)
|
|
||||||
})
|
|
||||||
Object.entries(providerInfo.headers ?? {}).forEach(([k, v]) => {
|
|
||||||
headers.set(k, v)
|
|
||||||
})
|
|
||||||
headers.delete("host")
|
|
||||||
headers.delete("content-length")
|
|
||||||
headers.delete("x-opencode-request")
|
|
||||||
headers.delete("x-opencode-session")
|
|
||||||
headers.delete("x-opencode-project")
|
|
||||||
headers.delete("x-opencode-client")
|
|
||||||
return headers
|
|
||||||
})(),
|
|
||||||
body: reqBody,
|
|
||||||
})
|
|
||||||
|
|
||||||
if (res.status !== 200) {
|
|
||||||
logger.metric({
|
|
||||||
"llm.error.code": res.status,
|
|
||||||
"llm.error.message": res.statusText,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try another provider => stop retrying if using fallback provider
|
|
||||||
if (
|
|
||||||
res.status !== 200 &&
|
|
||||||
// ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found.
|
|
||||||
res.status !== 404 &&
|
|
||||||
// ie. cannot change codex model providers mid-session
|
|
||||||
modelInfo.stickyProvider !== "strict" &&
|
|
||||||
modelInfo.fallbackProvider &&
|
|
||||||
providerInfo.id !== modelInfo.fallbackProvider
|
|
||||||
) {
|
|
||||||
return retriableRequest({
|
|
||||||
excludeProviders: [...retry.excludeProviders, providerInfo.id],
|
|
||||||
retryCount: retry.retryCount + 1,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return { providerInfo, reqBody, res, startTimestamp }
|
|
||||||
}
|
|
||||||
|
|
||||||
const { providerInfo, reqBody, res, startTimestamp } = await retriableRequest()
|
|
||||||
|
|
||||||
// Store model request
|
|
||||||
dataDumper?.provideModel(providerInfo.storeModel)
|
|
||||||
dataDumper?.provideRequest(reqBody)
|
|
||||||
|
|
||||||
// Store sticky provider
|
|
||||||
await stickyTracker?.set(providerInfo.id)
|
|
||||||
|
|
||||||
// Temporarily change 404 to 400 status code b/c solid start automatically override 404 response
|
|
||||||
const resStatus = res.status === 404 ? 400 : res.status
|
|
||||||
|
|
||||||
// Scrub response headers
|
|
||||||
const resHeaders = new Headers()
|
|
||||||
const keepHeaders = ["content-type", "cache-control"]
|
|
||||||
for (const [k, v] of res.headers.entries()) {
|
|
||||||
if (keepHeaders.includes(k.toLowerCase())) {
|
|
||||||
resHeaders.set(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.debug("STATUS: " + res.status + " " + res.statusText)
|
|
||||||
|
|
||||||
// Handle non-streaming response
|
|
||||||
if (!isStream) {
|
|
||||||
const json = await res.json()
|
|
||||||
const usageInfo = providerInfo.normalizeUsage(json.usage)
|
|
||||||
const costInfo = calculateCost(modelInfo, usageInfo)
|
|
||||||
await trialLimiter?.track(usageInfo)
|
|
||||||
await rateLimiter?.track()
|
|
||||||
await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
|
|
||||||
await reload(billingSource, authInfo, costInfo)
|
|
||||||
|
|
||||||
const responseConverter = createResponseConverter(providerInfo.format, opts.format)
|
|
||||||
const body = JSON.stringify(
|
|
||||||
responseConverter({
|
|
||||||
...json,
|
|
||||||
cost: calculateOccuredCost(billingSource, costInfo),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
logger.metric({ response_length: body.length })
|
|
||||||
logger.debug("RESPONSE: " + body)
|
|
||||||
dataDumper?.provideResponse(body)
|
|
||||||
dataDumper?.flush()
|
|
||||||
return new Response(body, {
|
|
||||||
status: resStatus,
|
|
||||||
statusText: res.statusText,
|
|
||||||
headers: resHeaders,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle streaming response
|
|
||||||
const streamConverter = createStreamPartConverter(providerInfo.format, opts.format)
|
|
||||||
const usageParser = providerInfo.createUsageParser()
|
|
||||||
const binaryDecoder = providerInfo.createBinaryStreamDecoder()
|
|
||||||
const stream = new ReadableStream({
|
|
||||||
start(c) {
|
|
||||||
const reader = res.body?.getReader()
|
|
||||||
const decoder = new TextDecoder()
|
|
||||||
const encoder = new TextEncoder()
|
|
||||||
|
|
||||||
let buffer = ""
|
|
||||||
let responseLength = 0
|
|
||||||
|
|
||||||
function pump(): Promise<void> {
|
|
||||||
return (
|
|
||||||
reader?.read().then(async ({ done, value: rawValue }) => {
|
|
||||||
if (done) {
|
|
||||||
logger.metric({
|
|
||||||
response_length: responseLength,
|
|
||||||
"timestamp.last_byte": Date.now(),
|
|
||||||
})
|
|
||||||
dataDumper?.flush()
|
|
||||||
await rateLimiter?.track()
|
|
||||||
const usage = usageParser.retrieve()
|
|
||||||
let cost = "0"
|
|
||||||
if (usage) {
|
|
||||||
const usageInfo = providerInfo.normalizeUsage(usage)
|
|
||||||
const costInfo = calculateCost(modelInfo, usageInfo)
|
|
||||||
await trialLimiter?.track(usageInfo)
|
|
||||||
await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
|
|
||||||
await reload(billingSource, authInfo, costInfo)
|
|
||||||
cost = calculateOccuredCost(billingSource, costInfo)
|
|
||||||
}
|
|
||||||
c.enqueue(encoder.encode(usageParser.buidlCostChunk(cost)))
|
|
||||||
c.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if (responseLength === 0) {
|
|
||||||
const now = Date.now()
|
|
||||||
logger.metric({
|
|
||||||
time_to_first_byte: now - startTimestamp,
|
|
||||||
"timestamp.first_byte": now,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const value = binaryDecoder ? binaryDecoder(rawValue) : rawValue
|
|
||||||
if (!value) return
|
|
||||||
|
|
||||||
responseLength += value.length
|
|
||||||
buffer += decoder.decode(value, { stream: true })
|
|
||||||
dataDumper?.provideStream(buffer)
|
|
||||||
|
|
||||||
const parts = buffer.split(providerInfo.streamSeparator)
|
|
||||||
buffer = parts.pop() ?? ""
|
|
||||||
|
|
||||||
for (let part of parts) {
|
|
||||||
logger.debug("PART: " + part)
|
|
||||||
|
|
||||||
part = part.trim()
|
|
||||||
usageParser.parse(part)
|
|
||||||
|
|
||||||
if (providerInfo.bodyModifier) {
|
|
||||||
for (const [k, v] of Object.entries(providerInfo.bodyModifier)) {
|
|
||||||
part = part.replace(k, v)
|
|
||||||
}
|
|
||||||
c.enqueue(encoder.encode(part + "\n\n"))
|
|
||||||
} else if (providerInfo.format !== opts.format) {
|
|
||||||
part = streamConverter(part)
|
|
||||||
c.enqueue(encoder.encode(part + "\n\n"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!providerInfo.bodyModifier && providerInfo.format === opts.format) {
|
|
||||||
c.enqueue(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
return pump()
|
|
||||||
}) || Promise.resolve()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return pump()
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return new Response(stream, {
|
|
||||||
status: resStatus,
|
|
||||||
statusText: res.statusText,
|
|
||||||
headers: resHeaders,
|
|
||||||
})
|
|
||||||
} catch (error: any) {
|
|
||||||
logger.metric({
|
|
||||||
"error.type": error.constructor.name,
|
|
||||||
"error.message": error.message,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Note: both top level "type" and "error.type" fields are used by the @ai-sdk/anthropic client to render the error message.
|
|
||||||
if (
|
|
||||||
error instanceof AuthError ||
|
|
||||||
error instanceof CreditsError ||
|
|
||||||
error instanceof MonthlyLimitError ||
|
|
||||||
error instanceof UserLimitError ||
|
|
||||||
error instanceof ModelError
|
|
||||||
)
|
|
||||||
return new Response(
|
|
||||||
JSON.stringify({
|
|
||||||
type: "error",
|
|
||||||
error: { type: error.constructor.name, message: error.message },
|
|
||||||
}),
|
|
||||||
{ status: 401 },
|
|
||||||
)
|
|
||||||
|
|
||||||
if (error instanceof FreeUsageLimitError || error instanceof SubscriptionUsageLimitError) {
|
|
||||||
const headers = new Headers()
|
|
||||||
if (error.retryAfter) {
|
|
||||||
headers.set("retry-after", String(error.retryAfter))
|
|
||||||
}
|
|
||||||
return new Response(
|
|
||||||
JSON.stringify({
|
|
||||||
type: "error",
|
|
||||||
error: { type: error.constructor.name, message: error.message },
|
|
||||||
}),
|
|
||||||
{ status: 429, headers },
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Response(
|
|
||||||
JSON.stringify({
|
|
||||||
type: "error",
|
|
||||||
error: {
|
|
||||||
type: "error",
|
|
||||||
message: error.message,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
{ status: 500 },
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
function validateModel(zenData: ZenData, reqModel: string) {
|
|
||||||
if (!(reqModel in zenData.models)) throw new ModelError(`Model ${reqModel} not supported`)
|
|
||||||
|
|
||||||
const modelId = reqModel as keyof typeof zenData.models
|
|
||||||
const modelData = Array.isArray(zenData.models[modelId])
|
|
||||||
? zenData.models[modelId].find((model) => opts.format === model.formatFilter)
|
|
||||||
: zenData.models[modelId]
|
|
||||||
|
|
||||||
if (!modelData) throw new ModelError(`Model ${reqModel} not supported for format ${opts.format}`)
|
|
||||||
|
|
||||||
logger.metric({ model: modelId })
|
|
||||||
|
|
||||||
return { id: modelId, ...modelData }
|
|
||||||
}
|
|
||||||
|
|
||||||
function selectProvider(
|
|
||||||
reqModel: string,
|
|
||||||
zenData: ZenData,
|
|
||||||
authInfo: AuthInfo,
|
|
||||||
modelInfo: ModelInfo,
|
|
||||||
sessionId: string,
|
|
||||||
isTrial: boolean,
|
|
||||||
retry: RetryOptions,
|
|
||||||
stickyProvider: string | undefined,
|
|
||||||
) {
|
|
||||||
const modelProvider = (() => {
|
|
||||||
if (authInfo?.provider?.credentials) {
|
|
||||||
return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isTrial) {
|
|
||||||
return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (stickyProvider) {
|
|
||||||
const provider = modelInfo.providers.find((provider) => provider.id === stickyProvider)
|
|
||||||
if (provider) return provider
|
|
||||||
}
|
|
||||||
|
|
||||||
if (retry.retryCount === MAX_FAILOVER_RETRIES) {
|
|
||||||
const provider = modelInfo.providers.find((provider) => provider.id === modelInfo.fallbackProvider)
|
|
||||||
if (provider) return provider
|
|
||||||
}
|
|
||||||
|
|
||||||
const providers = modelInfo.providers
|
|
||||||
.filter((provider) => !provider.disabled)
|
|
||||||
.filter((provider) => !retry.excludeProviders.includes(provider.id))
|
|
||||||
.flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
|
|
||||||
|
|
||||||
// Use the last 4 characters of session ID to select a provider
|
|
||||||
let h = 0
|
|
||||||
const l = sessionId.length
|
|
||||||
for (let i = l - 4; i < l; i++) {
|
|
||||||
h = (h * 31 + sessionId.charCodeAt(i)) | 0 // 32-bit int
|
|
||||||
}
|
|
||||||
const index = (h >>> 0) % providers.length // make unsigned + range 0..length-1
|
|
||||||
return providers[index || 0]
|
|
||||||
})()
|
|
||||||
|
|
||||||
if (!modelProvider) throw new ModelError("No provider available")
|
|
||||||
if (!(modelProvider.id in zenData.providers)) throw new ModelError(`Provider ${modelProvider.id} not supported`)
|
|
||||||
|
|
||||||
return {
|
|
||||||
...modelProvider,
|
|
||||||
...zenData.providers[modelProvider.id],
|
|
||||||
...(() => {
|
|
||||||
const format = zenData.providers[modelProvider.id].format
|
|
||||||
const providerModel = modelProvider.model
|
|
||||||
if (format === "anthropic") return anthropicHelper({ reqModel, providerModel })
|
|
||||||
if (format === "google") return googleHelper({ reqModel, providerModel })
|
|
||||||
if (format === "openai") return openaiHelper({ reqModel, providerModel })
|
|
||||||
return oaCompatHelper({ reqModel, providerModel })
|
|
||||||
})(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function authenticate(modelInfo: ModelInfo) {
|
|
||||||
const apiKey = opts.parseApiKey(input.request.headers)
|
|
||||||
if (!apiKey || apiKey === "public") {
|
|
||||||
if (modelInfo.allowAnonymous) return
|
|
||||||
throw new AuthError("Missing API key.")
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await Database.use((tx) =>
|
|
||||||
tx
|
|
||||||
.select({
|
|
||||||
apiKey: KeyTable.id,
|
|
||||||
workspaceID: KeyTable.workspaceID,
|
|
||||||
billing: {
|
|
||||||
balance: BillingTable.balance,
|
|
||||||
paymentMethodID: BillingTable.paymentMethodID,
|
|
||||||
monthlyLimit: BillingTable.monthlyLimit,
|
|
||||||
monthlyUsage: BillingTable.monthlyUsage,
|
|
||||||
timeMonthlyUsageUpdated: BillingTable.timeMonthlyUsageUpdated,
|
|
||||||
reloadTrigger: BillingTable.reloadTrigger,
|
|
||||||
timeReloadLockedTill: BillingTable.timeReloadLockedTill,
|
|
||||||
subscription: BillingTable.subscription,
|
|
||||||
},
|
|
||||||
user: {
|
|
||||||
id: UserTable.id,
|
|
||||||
monthlyLimit: UserTable.monthlyLimit,
|
|
||||||
monthlyUsage: UserTable.monthlyUsage,
|
|
||||||
timeMonthlyUsageUpdated: UserTable.timeMonthlyUsageUpdated,
|
|
||||||
},
|
|
||||||
subscription: {
|
|
||||||
id: SubscriptionTable.id,
|
|
||||||
rollingUsage: SubscriptionTable.rollingUsage,
|
|
||||||
fixedUsage: SubscriptionTable.fixedUsage,
|
|
||||||
timeRollingUpdated: SubscriptionTable.timeRollingUpdated,
|
|
||||||
timeFixedUpdated: SubscriptionTable.timeFixedUpdated,
|
|
||||||
},
|
|
||||||
provider: {
|
|
||||||
credentials: ProviderTable.credentials,
|
|
||||||
},
|
|
||||||
timeDisabled: ModelTable.timeCreated,
|
|
||||||
})
|
|
||||||
.from(KeyTable)
|
|
||||||
.innerJoin(WorkspaceTable, eq(WorkspaceTable.id, KeyTable.workspaceID))
|
|
||||||
.innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
|
|
||||||
.innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)))
|
|
||||||
.leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, modelInfo.id)))
|
|
||||||
.leftJoin(
|
|
||||||
ProviderTable,
|
|
||||||
modelInfo.byokProvider
|
|
||||||
? and(
|
|
||||||
eq(ProviderTable.workspaceID, KeyTable.workspaceID),
|
|
||||||
eq(ProviderTable.provider, modelInfo.byokProvider),
|
|
||||||
)
|
|
||||||
: sql`false`,
|
|
||||||
)
|
|
||||||
.leftJoin(
|
|
||||||
SubscriptionTable,
|
|
||||||
and(
|
|
||||||
eq(SubscriptionTable.workspaceID, KeyTable.workspaceID),
|
|
||||||
eq(SubscriptionTable.userID, KeyTable.userID),
|
|
||||||
isNull(SubscriptionTable.timeDeleted),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
|
|
||||||
.then((rows) => rows[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
if (!data) throw new AuthError("Invalid API key.")
|
|
||||||
logger.metric({
|
|
||||||
api_key: data.apiKey,
|
|
||||||
workspace: data.workspaceID,
|
|
||||||
isSubscription: data.subscription ? true : false,
|
|
||||||
subscription: data.billing.subscription?.plan,
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
apiKeyId: data.apiKey,
|
|
||||||
workspaceID: data.workspaceID,
|
|
||||||
billing: data.billing,
|
|
||||||
user: data.user,
|
|
||||||
subscription: data.subscription,
|
|
||||||
provider: data.provider,
|
|
||||||
isFree: FREE_WORKSPACES.includes(data.workspaceID),
|
|
||||||
isDisabled: !!data.timeDisabled,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo): BillingSource {
|
|
||||||
if (!authInfo) return "anonymous"
|
|
||||||
if (authInfo.provider?.credentials) return "byok"
|
|
||||||
if (authInfo.isFree) return "free"
|
|
||||||
if (modelInfo.allowAnonymous) return "free"
|
|
||||||
|
|
||||||
// Validate subscription billing
|
|
||||||
if (authInfo.billing.subscription && authInfo.subscription) {
|
|
||||||
try {
|
|
||||||
const sub = authInfo.subscription
|
|
||||||
const plan = authInfo.billing.subscription.plan
|
|
||||||
|
|
||||||
const formatRetryTime = (seconds: number) => {
|
|
||||||
const days = Math.floor(seconds / 86400)
|
|
||||||
if (days >= 1) return `${days} day${days > 1 ? "s" : ""}`
|
|
||||||
const hours = Math.floor(seconds / 3600)
|
|
||||||
const minutes = Math.ceil((seconds % 3600) / 60)
|
|
||||||
if (hours >= 1) return `${hours}hr ${minutes}min`
|
|
||||||
return `${minutes}min`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check weekly limit
|
|
||||||
if (sub.fixedUsage && sub.timeFixedUpdated) {
|
|
||||||
const result = Black.analyzeWeeklyUsage({
|
|
||||||
plan,
|
|
||||||
usage: sub.fixedUsage,
|
|
||||||
timeUpdated: sub.timeFixedUpdated,
|
|
||||||
})
|
|
||||||
if (result.status === "rate-limited")
|
|
||||||
throw new SubscriptionUsageLimitError(
|
|
||||||
`Subscription quota exceeded. Retry in ${formatRetryTime(result.resetInSec)}.`,
|
|
||||||
result.resetInSec,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check rolling limit
|
|
||||||
if (sub.rollingUsage && sub.timeRollingUpdated) {
|
|
||||||
const result = Black.analyzeRollingUsage({
|
|
||||||
plan,
|
|
||||||
usage: sub.rollingUsage,
|
|
||||||
timeUpdated: sub.timeRollingUpdated,
|
|
||||||
})
|
|
||||||
if (result.status === "rate-limited")
|
|
||||||
throw new SubscriptionUsageLimitError(
|
|
||||||
`Subscription quota exceeded. Retry in ${formatRetryTime(result.resetInSec)}.`,
|
|
||||||
result.resetInSec,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return "subscription"
|
|
||||||
} catch (e) {
|
|
||||||
if (!authInfo.billing.subscription.useBalance) throw e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate pay as you go billing
|
|
||||||
const billing = authInfo.billing
|
|
||||||
if (!billing.paymentMethodID)
|
|
||||||
throw new CreditsError(
|
|
||||||
`No payment method. Add a payment method here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
|
|
||||||
)
|
|
||||||
if (billing.balance <= 0)
|
|
||||||
throw new CreditsError(
|
|
||||||
`Insufficient balance. Manage your billing here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
|
|
||||||
)
|
|
||||||
|
|
||||||
const now = new Date()
|
|
||||||
const currentYear = now.getUTCFullYear()
|
|
||||||
const currentMonth = now.getUTCMonth()
|
|
||||||
if (
|
|
||||||
billing.monthlyLimit &&
|
|
||||||
billing.monthlyUsage &&
|
|
||||||
billing.timeMonthlyUsageUpdated &&
|
|
||||||
billing.monthlyUsage >= centsToMicroCents(billing.monthlyLimit * 100) &&
|
|
||||||
currentYear === billing.timeMonthlyUsageUpdated.getUTCFullYear() &&
|
|
||||||
currentMonth === billing.timeMonthlyUsageUpdated.getUTCMonth()
|
|
||||||
)
|
|
||||||
throw new MonthlyLimitError(
|
|
||||||
`Your workspace has reached its monthly spending limit of $${billing.monthlyLimit}. Manage your limits here: https://opencode.ai/workspace/${authInfo.workspaceID}/billing`,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
authInfo.user.monthlyLimit &&
|
|
||||||
authInfo.user.monthlyUsage &&
|
|
||||||
authInfo.user.timeMonthlyUsageUpdated &&
|
|
||||||
authInfo.user.monthlyUsage >= centsToMicroCents(authInfo.user.monthlyLimit * 100) &&
|
|
||||||
currentYear === authInfo.user.timeMonthlyUsageUpdated.getUTCFullYear() &&
|
|
||||||
currentMonth === authInfo.user.timeMonthlyUsageUpdated.getUTCMonth()
|
|
||||||
)
|
|
||||||
throw new UserLimitError(
|
|
||||||
`You have reached your monthly spending limit of $${authInfo.user.monthlyLimit}. Manage your limits here: https://opencode.ai/workspace/${authInfo.workspaceID}/members`,
|
|
||||||
)
|
|
||||||
|
|
||||||
return "balance"
|
|
||||||
}
|
|
||||||
|
|
||||||
function validateModelSettings(authInfo: AuthInfo) {
|
|
||||||
if (!authInfo) return
|
|
||||||
if (authInfo.isDisabled) throw new ModelError("Model is disabled")
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateProviderKey(authInfo: AuthInfo, providerInfo: ProviderInfo) {
|
|
||||||
if (!authInfo?.provider?.credentials) return
|
|
||||||
providerInfo.apiKey = authInfo.provider.credentials
|
|
||||||
}
|
|
||||||
|
|
||||||
async function fetchWith429Retry(url: string, options: RequestInit, retry = { count: 0 }) {
|
|
||||||
const res = await fetch(url, options)
|
|
||||||
if (res.status === 429 && retry.count < MAX_429_RETRIES) {
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, Math.pow(2, retry.count) * 500))
|
|
||||||
return fetchWith429Retry(url, options, { count: retry.count + 1 })
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
function calculateCost(modelInfo: ModelInfo, usageInfo: UsageInfo) {
|
|
||||||
const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
|
|
||||||
usageInfo
|
|
||||||
|
|
||||||
const modelCost =
|
|
||||||
modelInfo.cost200K &&
|
|
||||||
inputTokens + (cacheReadTokens ?? 0) + (cacheWrite5mTokens ?? 0) + (cacheWrite1hTokens ?? 0) > 200_000
|
|
||||||
? modelInfo.cost200K
|
|
||||||
: modelInfo.cost
|
|
||||||
|
|
||||||
const inputCost = modelCost.input * inputTokens * 100
|
|
||||||
const outputCost = modelCost.output * outputTokens * 100
|
|
||||||
const reasoningCost = (() => {
|
|
||||||
if (!reasoningTokens) return undefined
|
|
||||||
return modelCost.output * reasoningTokens * 100
|
|
||||||
})()
|
|
||||||
const cacheReadCost = (() => {
|
|
||||||
if (!cacheReadTokens) return undefined
|
|
||||||
if (!modelCost.cacheRead) return undefined
|
|
||||||
return modelCost.cacheRead * cacheReadTokens * 100
|
|
||||||
})()
|
|
||||||
const cacheWrite5mCost = (() => {
|
|
||||||
if (!cacheWrite5mTokens) return undefined
|
|
||||||
if (!modelCost.cacheWrite5m) return undefined
|
|
||||||
return modelCost.cacheWrite5m * cacheWrite5mTokens * 100
|
|
||||||
})()
|
|
||||||
const cacheWrite1hCost = (() => {
|
|
||||||
if (!cacheWrite1hTokens) return undefined
|
|
||||||
if (!modelCost.cacheWrite1h) return undefined
|
|
||||||
return modelCost.cacheWrite1h * cacheWrite1hTokens * 100
|
|
||||||
})()
|
|
||||||
const totalCostInCent =
|
|
||||||
inputCost +
|
|
||||||
outputCost +
|
|
||||||
(reasoningCost ?? 0) +
|
|
||||||
(cacheReadCost ?? 0) +
|
|
||||||
(cacheWrite5mCost ?? 0) +
|
|
||||||
(cacheWrite1hCost ?? 0)
|
|
||||||
return {
|
|
||||||
totalCostInCent,
|
|
||||||
inputCost,
|
|
||||||
outputCost,
|
|
||||||
reasoningCost,
|
|
||||||
cacheReadCost,
|
|
||||||
cacheWrite5mCost,
|
|
||||||
cacheWrite1hCost,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function calculateOccuredCost(billingSource: BillingSource, costInfo: CostInfo) {
|
|
||||||
return billingSource === "balance" ? (costInfo.totalCostInCent / 100).toFixed(8) : "0"
|
|
||||||
}
|
|
||||||
|
|
||||||
async function trackUsage(
|
|
||||||
billingSource: BillingSource,
|
|
||||||
authInfo: AuthInfo,
|
|
||||||
modelInfo: ModelInfo,
|
|
||||||
providerInfo: ProviderInfo,
|
|
||||||
usageInfo: UsageInfo,
|
|
||||||
costInfo: CostInfo,
|
|
||||||
) {
|
|
||||||
const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
|
|
||||||
usageInfo
|
|
||||||
const { totalCostInCent, inputCost, outputCost, reasoningCost, cacheReadCost, cacheWrite5mCost, cacheWrite1hCost } =
|
|
||||||
costInfo
|
|
||||||
|
|
||||||
logger.metric({
|
|
||||||
"tokens.input": inputTokens,
|
|
||||||
"tokens.output": outputTokens,
|
|
||||||
"tokens.reasoning": reasoningTokens,
|
|
||||||
"tokens.cache_read": cacheReadTokens,
|
|
||||||
"tokens.cache_write_5m": cacheWrite5mTokens,
|
|
||||||
"tokens.cache_write_1h": cacheWrite1hTokens,
|
|
||||||
"cost.input": Math.round(inputCost),
|
|
||||||
"cost.output": Math.round(outputCost),
|
|
||||||
"cost.reasoning": reasoningCost ? Math.round(reasoningCost) : undefined,
|
|
||||||
"cost.cache_read": cacheReadCost ? Math.round(cacheReadCost) : undefined,
|
|
||||||
"cost.cache_write_5m": cacheWrite5mCost ? Math.round(cacheWrite5mCost) : undefined,
|
|
||||||
"cost.cache_write_1h": cacheWrite1hCost ? Math.round(cacheWrite1hCost) : undefined,
|
|
||||||
"cost.total": Math.round(totalCostInCent),
|
|
||||||
})
|
|
||||||
|
|
||||||
if (billingSource === "anonymous") return
|
|
||||||
authInfo = authInfo!
|
|
||||||
|
|
||||||
const cost = centsToMicroCents(totalCostInCent)
|
|
||||||
await Database.use((db) =>
|
|
||||||
Promise.all([
|
|
||||||
db.insert(UsageTable).values({
|
|
||||||
workspaceID: authInfo.workspaceID,
|
|
||||||
id: Identifier.create("usage"),
|
|
||||||
model: modelInfo.id,
|
|
||||||
provider: providerInfo.id,
|
|
||||||
inputTokens,
|
|
||||||
outputTokens,
|
|
||||||
reasoningTokens,
|
|
||||||
cacheReadTokens,
|
|
||||||
cacheWrite5mTokens,
|
|
||||||
cacheWrite1hTokens,
|
|
||||||
cost,
|
|
||||||
keyID: authInfo.apiKeyId,
|
|
||||||
enrichment: billingSource === "subscription" ? { plan: "sub" } : undefined,
|
|
||||||
}),
|
|
||||||
db
|
|
||||||
.update(KeyTable)
|
|
||||||
.set({ timeUsed: sql`now()` })
|
|
||||||
.where(and(eq(KeyTable.workspaceID, authInfo.workspaceID), eq(KeyTable.id, authInfo.apiKeyId))),
|
|
||||||
...(billingSource === "subscription"
|
|
||||||
? (() => {
|
|
||||||
const plan = authInfo.billing.subscription!.plan
|
|
||||||
const black = BlackData.getLimits({ plan })
|
|
||||||
const week = getWeekBounds(new Date())
|
|
||||||
const rollingWindowSeconds = black.rollingWindow * 3600
|
|
||||||
return [
|
|
||||||
db
|
|
||||||
.update(SubscriptionTable)
|
|
||||||
.set({
|
|
||||||
fixedUsage: sql`
|
|
||||||
CASE
|
|
||||||
WHEN ${SubscriptionTable.timeFixedUpdated} >= ${week.start} THEN ${SubscriptionTable.fixedUsage} + ${cost}
|
|
||||||
ELSE ${cost}
|
|
||||||
END
|
|
||||||
`,
|
|
||||||
timeFixedUpdated: sql`now()`,
|
|
||||||
rollingUsage: sql`
|
|
||||||
CASE
|
|
||||||
WHEN UNIX_TIMESTAMP(${SubscriptionTable.timeRollingUpdated}) >= UNIX_TIMESTAMP(now()) - ${rollingWindowSeconds} THEN ${SubscriptionTable.rollingUsage} + ${cost}
|
|
||||||
ELSE ${cost}
|
|
||||||
END
|
|
||||||
`,
|
|
||||||
timeRollingUpdated: sql`
|
|
||||||
CASE
|
|
||||||
WHEN UNIX_TIMESTAMP(${SubscriptionTable.timeRollingUpdated}) >= UNIX_TIMESTAMP(now()) - ${rollingWindowSeconds} THEN ${SubscriptionTable.timeRollingUpdated}
|
|
||||||
ELSE now()
|
|
||||||
END
|
|
||||||
`,
|
|
||||||
})
|
|
||||||
.where(
|
|
||||||
and(
|
|
||||||
eq(SubscriptionTable.workspaceID, authInfo.workspaceID),
|
|
||||||
eq(SubscriptionTable.userID, authInfo.user.id),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
})()
|
|
||||||
: [
|
|
||||||
db
|
|
||||||
.update(BillingTable)
|
|
||||||
.set({
|
|
||||||
balance: authInfo.isFree
|
|
||||||
? sql`${BillingTable.balance} - ${0}`
|
|
||||||
: sql`${BillingTable.balance} - ${cost}`,
|
|
||||||
monthlyUsage: sql`
|
|
||||||
CASE
|
|
||||||
WHEN MONTH(${BillingTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${BillingTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${BillingTable.monthlyUsage} + ${cost}
|
|
||||||
ELSE ${cost}
|
|
||||||
END
|
|
||||||
`,
|
|
||||||
timeMonthlyUsageUpdated: sql`now()`,
|
|
||||||
})
|
|
||||||
.where(eq(BillingTable.workspaceID, authInfo.workspaceID)),
|
|
||||||
db
|
|
||||||
.update(UserTable)
|
|
||||||
.set({
|
|
||||||
monthlyUsage: sql`
|
|
||||||
CASE
|
|
||||||
WHEN MONTH(${UserTable.timeMonthlyUsageUpdated}) = MONTH(now()) AND YEAR(${UserTable.timeMonthlyUsageUpdated}) = YEAR(now()) THEN ${UserTable.monthlyUsage} + ${cost}
|
|
||||||
ELSE ${cost}
|
|
||||||
END
|
|
||||||
`,
|
|
||||||
timeMonthlyUsageUpdated: sql`now()`,
|
|
||||||
})
|
|
||||||
.where(and(eq(UserTable.workspaceID, authInfo.workspaceID), eq(UserTable.id, authInfo.user.id))),
|
|
||||||
]),
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
return { costInMicroCents: cost }
|
|
||||||
}
|
|
||||||
|
|
||||||
async function reload(billingSource: BillingSource, authInfo: AuthInfo, costInfo: CostInfo) {
|
|
||||||
if (billingSource !== "balance") return
|
|
||||||
authInfo = authInfo!
|
|
||||||
|
|
||||||
const reloadTrigger = centsToMicroCents((authInfo.billing.reloadTrigger ?? Billing.RELOAD_TRIGGER) * 100)
|
|
||||||
if (authInfo.billing.balance - costInfo.totalCostInCent >= reloadTrigger) return
|
|
||||||
if (authInfo.billing.timeReloadLockedTill && authInfo.billing.timeReloadLockedTill > new Date()) return
|
|
||||||
|
|
||||||
const lock = await Database.use((tx) =>
|
|
||||||
tx
|
|
||||||
.update(BillingTable)
|
|
||||||
.set({
|
|
||||||
timeReloadLockedTill: sql`now() + interval 1 minute`,
|
|
||||||
})
|
|
||||||
.where(
|
|
||||||
and(
|
|
||||||
eq(BillingTable.workspaceID, authInfo.workspaceID),
|
|
||||||
eq(BillingTable.reload, true),
|
|
||||||
lt(BillingTable.balance, reloadTrigger),
|
|
||||||
or(isNull(BillingTable.timeReloadLockedTill), lt(BillingTable.timeReloadLockedTill, sql`now()`)),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if (lock.rowsAffected === 0) return
|
|
||||||
|
|
||||||
await Actor.provide("system", { workspaceID: authInfo.workspaceID }, async () => {
|
|
||||||
await Billing.reload()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -3,6 +3,7 @@ import { Tool } from "./tool"
|
||||||
import TurndownService from "turndown"
|
import TurndownService from "turndown"
|
||||||
import DESCRIPTION from "./webfetch.txt"
|
import DESCRIPTION from "./webfetch.txt"
|
||||||
import { abortAfterAny } from "../util/abort"
|
import { abortAfterAny } from "../util/abort"
|
||||||
|
import { Identifier } from "../id/id"
|
||||||
|
|
||||||
const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB
|
const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB
|
||||||
const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
|
const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds
|
||||||
|
|
@ -87,11 +88,34 @@ export const WebFetchTool = Tool.define("webfetch", {
|
||||||
throw new Error("Response too large (exceeds 5MB limit)")
|
throw new Error("Response too large (exceeds 5MB limit)")
|
||||||
}
|
}
|
||||||
|
|
||||||
const content = new TextDecoder().decode(arrayBuffer)
|
|
||||||
const contentType = response.headers.get("content-type") || ""
|
const contentType = response.headers.get("content-type") || ""
|
||||||
|
const mime = contentType.split(";")[0]?.trim().toLowerCase() || ""
|
||||||
const title = `${params.url} (${contentType})`
|
const title = `${params.url} (${contentType})`
|
||||||
|
|
||||||
|
// Check if response is an image
|
||||||
|
const isImage = mime.startsWith("image/") && mime !== "image/svg+xml" && mime !== "image/vnd.fastbidsheet"
|
||||||
|
|
||||||
|
if (isImage) {
|
||||||
|
const base64Content = Buffer.from(arrayBuffer).toString("base64")
|
||||||
|
return {
|
||||||
|
title,
|
||||||
|
output: "Image fetched successfully",
|
||||||
|
metadata: {},
|
||||||
|
attachments: [
|
||||||
|
{
|
||||||
|
id: Identifier.ascending("part"),
|
||||||
|
sessionID: ctx.sessionID,
|
||||||
|
messageID: ctx.messageID,
|
||||||
|
type: "file",
|
||||||
|
mime,
|
||||||
|
url: `data:${mime};base64,${base64Content}`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const content = new TextDecoder().decode(arrayBuffer)
|
||||||
|
|
||||||
// Handle content based on requested format and actual content type
|
// Handle content based on requested format and actual content type
|
||||||
switch (params.format) {
|
switch (params.format) {
|
||||||
case "markdown":
|
case "markdown":
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
import { describe, expect, test } from "bun:test"
|
||||||
|
import path from "path"
|
||||||
|
import { Instance } from "../../src/project/instance"
|
||||||
|
import { WebFetchTool } from "../../src/tool/webfetch"
|
||||||
|
|
||||||
|
const projectRoot = path.join(import.meta.dir, "../..")
|
||||||
|
|
||||||
|
const ctx = {
|
||||||
|
sessionID: "test",
|
||||||
|
messageID: "message",
|
||||||
|
callID: "",
|
||||||
|
agent: "build",
|
||||||
|
abort: AbortSignal.any([]),
|
||||||
|
messages: [],
|
||||||
|
metadata: () => {},
|
||||||
|
ask: async () => {},
|
||||||
|
}
|
||||||
|
|
||||||
|
async function withFetch(
|
||||||
|
mockFetch: (input: string | URL | Request, init?: RequestInit) => Promise<Response>,
|
||||||
|
fn: () => Promise<void>,
|
||||||
|
) {
|
||||||
|
const originalFetch = globalThis.fetch
|
||||||
|
globalThis.fetch = mockFetch as unknown as typeof fetch
|
||||||
|
try {
|
||||||
|
await fn()
|
||||||
|
} finally {
|
||||||
|
globalThis.fetch = originalFetch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("tool.webfetch", () => {
|
||||||
|
test("returns image responses as file attachments", async () => {
|
||||||
|
const bytes = new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10])
|
||||||
|
await withFetch(
|
||||||
|
async () => new Response(bytes, { status: 200, headers: { "content-type": "IMAGE/PNG; charset=binary" } }),
|
||||||
|
async () => {
|
||||||
|
await Instance.provide({
|
||||||
|
directory: projectRoot,
|
||||||
|
fn: async () => {
|
||||||
|
const webfetch = await WebFetchTool.init()
|
||||||
|
const result = await webfetch.execute({ url: "https://example.com/image.png", format: "markdown" }, ctx)
|
||||||
|
expect(result.output).toBe("Image fetched successfully")
|
||||||
|
expect(result.attachments).toBeDefined()
|
||||||
|
expect(result.attachments?.length).toBe(1)
|
||||||
|
expect(result.attachments?.[0].type).toBe("file")
|
||||||
|
expect(result.attachments?.[0].mime).toBe("image/png")
|
||||||
|
expect(result.attachments?.[0].url.startsWith("data:image/png;base64,")).toBe(true)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("keeps svg as text output", async () => {
|
||||||
|
const svg = '<svg xmlns="http://www.w3.org/2000/svg"><text>hello</text></svg>'
|
||||||
|
await withFetch(
|
||||||
|
async () =>
|
||||||
|
new Response(svg, {
|
||||||
|
status: 200,
|
||||||
|
headers: { "content-type": "image/svg+xml; charset=UTF-8" },
|
||||||
|
}),
|
||||||
|
async () => {
|
||||||
|
await Instance.provide({
|
||||||
|
directory: projectRoot,
|
||||||
|
fn: async () => {
|
||||||
|
const webfetch = await WebFetchTool.init()
|
||||||
|
const result = await webfetch.execute({ url: "https://example.com/image.svg", format: "html" }, ctx)
|
||||||
|
expect(result.output).toContain("<svg")
|
||||||
|
expect(result.attachments).toBeUndefined()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
test("keeps text responses as text output", async () => {
|
||||||
|
await withFetch(
|
||||||
|
async () =>
|
||||||
|
new Response("hello from webfetch", {
|
||||||
|
status: 200,
|
||||||
|
headers: { "content-type": "text/plain; charset=utf-8" },
|
||||||
|
}),
|
||||||
|
async () => {
|
||||||
|
await Instance.provide({
|
||||||
|
directory: projectRoot,
|
||||||
|
fn: async () => {
|
||||||
|
const webfetch = await WebFetchTool.init()
|
||||||
|
const result = await webfetch.execute({ url: "https://example.com/file.txt", format: "text" }, ctx)
|
||||||
|
expect(result.output).toBe("hello from webfetch")
|
||||||
|
expect(result.attachments).toBeUndefined()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -1548,8 +1548,8 @@ export type ProviderConfig = {
|
||||||
[key: string]: string
|
[key: string]: string
|
||||||
}
|
}
|
||||||
provider?: {
|
provider?: {
|
||||||
npm: string
|
npm?: string
|
||||||
api: string
|
api?: string
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Variant-specific configuration
|
* Variant-specific configuration
|
||||||
|
|
@ -4068,8 +4068,8 @@ export type ProviderListResponses = {
|
||||||
[key: string]: string
|
[key: string]: string
|
||||||
}
|
}
|
||||||
provider?: {
|
provider?: {
|
||||||
npm: string
|
npm?: string
|
||||||
api: string
|
api?: string
|
||||||
}
|
}
|
||||||
variants?: {
|
variants?: {
|
||||||
[key: string]: {
|
[key: string]: {
|
||||||
|
|
|
||||||
|
|
@ -3800,8 +3800,7 @@
|
||||||
"api": {
|
"api": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"required": ["npm", "api"]
|
|
||||||
},
|
},
|
||||||
"variants": {
|
"variants": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|
@ -9405,8 +9404,7 @@
|
||||||
"api": {
|
"api": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"required": ["npm", "api"]
|
|
||||||
},
|
},
|
||||||
"variants": {
|
"variants": {
|
||||||
"description": "Variant-specific configuration",
|
"description": "Variant-specific configuration",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue