diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 32750221eb..c7443b2cd0 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -1,7 +1,7 @@ import z from "zod" import fuzzysort from "fuzzysort" import { Config } from "../config/config" -import { mergeDeep, sortBy } from "remeda" +import { entries, mapValues, mergeDeep, pipe, sortBy } from "remeda" import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai" import { Log } from "../util/log" import { BunProc } from "../bun" @@ -43,7 +43,7 @@ export namespace Provider { "@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible, } - type CustomLoader = (provider: ModelsDev.Provider) => Promise<{ + type CustomLoader = (provider: Info) => Promise<{ autoload: boolean getModel?: (sdk: any, modelID: string, options?: Record) => Promise options?: Record @@ -299,18 +299,155 @@ export namespace Provider { }, } - export type Model = { - providerID: string - modelID: string + export const Model = z + .object({ + id: z.string(), + providerID: z.string(), + api: z.object({ + id: z.string(), + url: z.string(), + npm: z.string(), + }), + name: z.string(), + capabilities: z.object({ + temperature: z.boolean(), + reasoning: z.boolean(), + attachment: z.boolean(), + toolcall: z.boolean(), + input: { + text: z.boolean(), + audio: z.boolean(), + image: z.boolean(), + video: z.boolean(), + pdf: z.boolean(), + }, + output: { + text: z.boolean(), + audio: z.boolean(), + image: z.boolean(), + video: z.boolean(), + pdf: z.boolean(), + }, + }), + cost: z.object({ + input: z.number(), + output: z.number(), + cache: z.object({ + read: z.number(), + write: z.number(), + }), + experimentalOver200K: z + .object({ + input: z.number(), + output: z.number(), + cache: z.object({ + read: z.number(), + write: z.number(), + }), + }) + .optional(), + }), + limit: z.object({ + context: z.number(), + output: z.number(), + }), + status: z.enum(["alpha", "beta", "deprecated", "active"]), + options: z.record(z.string(), z.any()), + headers: z.record(z.string(), z.string()), + }) + .meta({ + ref: "Model", + }) + export type Model = z.infer + + export const Info = z.object({ + id: z.string(), + name: z.string(), + source: z.enum(["env", "config", "custom", "api"]), + env: z.string().array(), + key: z.string().optional(), + options: z.record(z.string(), z.any()), + models: z.record(z.string(), Model), + }) + export type Info = z.infer + + function fromModelsDevModel(provider: ModelsDev.Provider, model: ModelsDev.Model): Model { + return { + id: model.id, + name: model.name, + api: { + id: model.id, + url: provider.api!, + npm: model.provider?.npm ?? provider.npm ?? provider.id, + }, + status: model.status ?? "active", + headers: model.headers ?? {}, + options: model.options ?? {}, + cost: { + input: model.cost.input, + output: model.cost.output, + cache: { + read: model.cost.cache_read ?? 0, + write: model.cost.cache_write ?? 0, + }, + experimentalOver200K: model.cost.context_over_200k + ? { + cache: { + read: model.cost.context_over_200k.cache_read ?? 0, + write: model.cost.context_over_200k.cache_write ?? 0, + }, + input: model.cost.context_over_200k.input, + output: model.cost.context_over_200k.output, + } + : undefined, + }, + limit: { + context: model.limit.context, + output: model.limit.output, + }, + capabilities: { + temperature: model.temperature, + reasoning: model.reasoning, + attachment: model.attachment, + toolcall: model.tool_call, + input: { + text: model.modalities?.input?.includes("text") ?? false, + audio: model.modalities?.input?.includes("audio") ?? false, + image: model.modalities?.input?.includes("image") ?? false, + video: model.modalities?.input?.includes("video") ?? false, + pdf: model.modalities?.input?.includes("pdf") ?? false, + }, + output: { + text: model.modalities?.output?.includes("text") ?? false, + audio: model.modalities?.output?.includes("audio") ?? false, + image: model.modalities?.output?.includes("image") ?? false, + video: model.modalities?.output?.includes("video") ?? false, + pdf: model.modalities?.output?.includes("pdf") ?? false, + }, + }, + } + } + + function fromModelsDevProvider(provider: ModelsDev.Provider): Info { + return { + id: provider.id, + source: "custom", + name: provider.name, + env: provider.env ?? [], + options: {}, + models: mapValues(provider.models, (model) => fromModelsDevModel(provider, model)), + } + } + + export type ModelWithStuff = { language: LanguageModel - info: ModelsDev.Model - npm: string + info: Model } const state = Instance.state(async () => { using _ = log.time("state") const config = await Config.get() - const database = await ModelsDev.get() + const database = mapValues(await ModelsDev.get(), fromModelsDevProvider) const disabled = new Set(config.disabled_providers ?? []) const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null @@ -321,43 +458,12 @@ export namespace Provider { return true } - const providers: { - [providerID: string]: { - source: Source - info: ModelsDev.Provider - getModel?: (sdk: any, modelID: string, options?: Record) => Promise - options: Record - } - } = {} - const models = new Map() + const providers: { [providerID: string]: Info } = {} + const models = new Map() const sdk = new Map() log.info("init") - function mergeProvider( - id: string, - options: Record, - source: Source, - getModel?: (sdk: any, modelID: string, options?: Record) => Promise, - ) { - const provider = providers[id] - if (!provider) { - const info = database[id] - if (!info) return - if (info.api && !options["baseURL"]) options["baseURL"] = info.api - providers[id] = { - source, - info, - options, - getModel, - } - return - } - provider.options = mergeDeep(provider.options, options) - provider.source = source - provider.getModel = getModel ?? provider.getModel - } - const configProviders = Object.entries(config.provider ?? {}) // Add GitHub Copilot Enterprise provider that inherits from GitHub Copilot @@ -367,11 +473,17 @@ export namespace Provider { ...githubCopilot, id: "github-copilot-enterprise", name: "GitHub Copilot Enterprise", - // Enterprise uses a different API endpoint - will be set dynamically based on auth - api: undefined, } } + function mergeProvider(providerID: string, provider: Partial) { + const match = database[providerID] + if (!match) return + // @ts-expect-error + providers[providerID] = mergeDeep(match, provider) + } + + // TODO: load config for (const [providerID, provider] of configProviders) { const existing = database[providerID] const parsed: ModelsDev.Provider = { @@ -390,29 +502,27 @@ export namespace Provider { if (model.id && model.id !== modelID) return modelID return existing?.name ?? modelID }) - const parsedModel: ModelsDev.Model = { + const parsedModel: Model = { id: modelID, - target: model.target ?? existing?.target ?? modelID, + apiID: model.target ?? existing?.target ?? modelID, + status: model.status ?? existing?.status ?? "alpha", name, - release_date: model.release_date ?? existing?.release_date, - attachment: model.attachment ?? existing?.attachment ?? false, - reasoning: model.reasoning ?? existing?.reasoning ?? false, - temperature: model.temperature ?? existing?.temperature ?? false, - tool_call: model.tool_call ?? existing?.tool_call ?? true, - cost: - !model.cost && !existing?.cost - ? { - input: 0, - output: 0, - cache_read: 0, - cache_write: 0, - } - : { - cache_read: 0, - cache_write: 0, - ...existing?.cost, - ...model.cost, - }, + providerID, + npm: model.provider?.npm ?? existing?.provider?.npm ?? provider.npm ?? providerID, + support: { + temperature: model.temperature ?? existing?.temperature ?? false, + reasoning: model.reasoning ?? existing?.reasoning ?? false, + attachment: model.attachment ?? existing?.attachment ?? false, + toolcall: model.tool_call ?? existing?.tool_call ?? true, + }, + cost: { + input: model?.cost?.input ?? existing?.cost?.input ?? 0, + output: model?.cost?.output ?? existing?.cost?.output ?? 0, + cache: { + read: model?.cost?.cache_read ?? existing?.cost?.cache_read ?? 0, + write: model?.cost?.cache_write ?? existing?.cost?.cache_write ?? 0, + }, + }, options: { ...existing?.options, ...model.options, @@ -427,8 +537,7 @@ export namespace Provider { input: ["text"], output: ["text"], }, - headers: model.headers, - provider: model.provider ?? existing?.provider, + headers: model.headers ?? {}, } parsed.models[modelID] = parsedModel } @@ -442,19 +551,20 @@ export namespace Provider { if (disabled.has(providerID)) continue const apiKey = provider.env.map((item) => env[item]).find(Boolean) if (!apiKey) continue - mergeProvider( - providerID, - // only include apiKey if there's only one potential option - provider.env.length === 1 ? { apiKey } : {}, - "env", - ) + mergeProvider(providerID, { + source: "env", + key: provider.env.length === 1 ? apiKey : undefined, + }) } // load apikeys for (const [providerID, provider] of Object.entries(await Auth.all())) { if (disabled.has(providerID)) continue if (provider.type === "api") { - mergeProvider(providerID, { apiKey: provider.key }, "api") + mergeProvider(providerID, { + source: "api", + key: provider.key, + }) } } @@ -480,7 +590,10 @@ export namespace Provider { // Load for the main provider if auth exists if (auth) { const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider]) - mergeProvider(plugin.auth.provider, options ?? {}, "custom") + mergeProvider(plugin.auth.provider, { + source: "custom", + options: options, + }) } // If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists @@ -493,7 +606,10 @@ export namespace Provider { () => Auth.get(enterpriseProviderID) as any, database[enterpriseProviderID], ) - mergeProvider(enterpriseProviderID, enterpriseOptions ?? {}, "custom") + mergeProvider(enterpriseProviderID, { + source: "custom", + options: enterpriseOptions, + }) } } } @@ -503,13 +619,22 @@ export namespace Provider { if (disabled.has(providerID)) continue const result = await fn(database[providerID]) if (result && (result.autoload || providers[providerID])) { - mergeProvider(providerID, result.options ?? {}, "custom", result.getModel) + mergeProvider(providerID, { + source: "custom", + options: result.options, + }) } } // load config for (const [providerID, provider] of configProviders) { - mergeProvider(providerID, provider.options ?? {}, "config") + mergeProvider(providerID, { + source: "config", + env: provider.env, + name: provider.name, + options: provider.options, + // TODO: merge models + }) } for (const [providerID, provider] of Object.entries(providers)) { @@ -519,33 +644,36 @@ export namespace Provider { } if (providerID === "github-copilot" || providerID === "github-copilot-enterprise") { - provider.info.npm = "@ai-sdk/github-copilot" + provider.models = mapValues(provider.models, (model) => ({ + ...model, + api: { + ...model.api, + npm: "@ai-sdk/github-copilot", + }, + })) } const configProvider = config.provider?.[providerID] - for (const [modelID, model] of Object.entries(provider.info.models)) { - model.target = model.target ?? model.id ?? modelID + for (const [modelID, model] of Object.entries(provider.models)) { + model.api.id = model.api.id ?? model.id ?? modelID if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat")) - delete provider.info.models[modelID] - if ( - ((model.status === "alpha" || model.experimental) && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) || - model.status === "deprecated" - ) - delete provider.info.models[modelID] + delete provider.models[modelID] + if ((model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) || model.status === "deprecated") + delete provider.models[modelID] if ( (configProvider?.blacklist && configProvider.blacklist.includes(modelID)) || (configProvider?.whitelist && !configProvider.whitelist.includes(modelID)) ) - delete provider.info.models[modelID] + delete provider.models[modelID] } - if (Object.keys(provider.info.models).length === 0) { + if (Object.keys(provider.models).length === 0) { delete providers[providerID] continue } - log.info("found", { providerID, npm: provider.info.npm }) + log.info("found", { providerID }) } return { @@ -559,18 +687,28 @@ export namespace Provider { return state().then((state) => state.providers) } - async function getSDK(npm: string, providerID: string) { + async function getSDK(model: Model) { try { using _ = log.time("getSDK", { - providerID, + providerID: model.providerID, }) const s = await state() - const options = { ...s.providers[providerID]?.options } - if (npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) { + const provider = s.providers[model.providerID] + const options = { ...provider.options } + + if (model.api.npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) { options["includeUsage"] = true } - const key = Bun.hash.xxHash32(JSON.stringify({ pkg: npm, options })) + if (!options["baseURL"]) options["baseURL"] = model.api.url + if (!options["apiKey"]) options["apiKey"] = provider.key + if (model.headers) + options["headers"] = { + ...options["headers"], + ...model.headers, + } + + const key = Bun.hash.xxHash32(JSON.stringify({ npm: model.api.npm, options })) const existing = s.sdk.get(key) if (existing) return existing @@ -599,12 +737,13 @@ export namespace Provider { } // Special case: google-vertex-anthropic uses a subpath import - const bundledKey = providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : npm + const bundledKey = + model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm const bundledFn = BUNDLED_PROVIDERS[bundledKey] if (bundledFn) { - log.info("using bundled provider", { providerID, pkg: bundledKey }) + log.info("using bundled provider", { providerID: model.providerID, pkg: bundledKey }) const loaded = bundledFn({ - name: providerID, + name: model.providerID, ...options, }) s.sdk.set(key, loaded) @@ -612,24 +751,24 @@ export namespace Provider { } let installedPath: string - if (!npm.startsWith("file://")) { - installedPath = await BunProc.install(npm, "latest") + if (!model.api.npm.startsWith("file://")) { + installedPath = await BunProc.install(model.api.npm, "latest") } else { - log.info("loading local provider", { pkg: npm }) - installedPath = npm + log.info("loading local provider", { pkg: model.api.npm }) + installedPath = model.api.npm } const mod = await import(installedPath) const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!] const loaded = fn({ - name: providerID, + name: model.providerID, ...options, }) s.sdk.set(key, loaded) return loaded as SDK } catch (e) { - throw new InitError({ providerID }, { cause: e }) + throw new InitError({ providerID: model.providerID }, { cause: e }) } } @@ -655,22 +794,25 @@ export namespace Provider { throw new ModelNotFoundError({ providerID, modelID, suggestions }) } - const info = provider.info.models[modelID] + const info = provider.models[modelID] if (!info) { - const availableModels = Object.keys(provider.info.models) + const availableModels = Object.keys(provider.models) const matches = fuzzysort.go(modelID, availableModels, { limit: 3, threshold: -10000 }) const suggestions = matches.map((m) => m.target) throw new ModelNotFoundError({ providerID, modelID, suggestions }) } - const npm = info.provider?.npm ?? provider.info.npm ?? info.id - const sdk = await getSDK(npm, providerID) + const sdk = await getSDK(info) try { const language = provider.getModel - ? await provider.getModel(sdk, info.target, provider.options) - : sdk.languageModel(info.target) + ? await provider.getModel(sdk, info.api.id, provider.options) + : sdk.languageModel(info.api.id) log.info("found", { providerID, modelID }) + const cached: ModelWithStuff = { + info, + language, + } s.models.set(key, { providerID, modelID, @@ -755,7 +897,7 @@ export namespace Provider { } const priority = ["gpt-5", "claude-sonnet-4", "big-pickle", "gemini-3-pro"] - export function sort(models: ModelsDev.Model[]) { + export function sort(models: Model[]) { return sortBy( models, [(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],