diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index 2bdbbca5b0..6ca8fe55be 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -542,7 +542,15 @@ export namespace Config { .extend({ whitelist: z.array(z.string()).optional(), blacklist: z.array(z.string()).optional(), - models: z.record(z.string(), ModelsDev.Model.partial()).optional(), + models: z + .record( + z.string(), + ModelsDev.Model.partial().refine( + (input) => input.id === undefined, + "The model.id field can no longer be specified. Use model.target to specify an alternate model id to use when calling the provider.", + ), + ) + .optional(), options: z .object({ apiKey: z.string().optional(), diff --git a/packages/opencode/src/provider/models.ts b/packages/opencode/src/provider/models.ts index 676837e152..f8ff2e86a4 100644 --- a/packages/opencode/src/provider/models.ts +++ b/packages/opencode/src/provider/models.ts @@ -13,6 +13,7 @@ export namespace ModelsDev { .object({ id: z.string(), name: z.string(), + target: z.string(), release_date: z.string(), attachment: z.boolean(), reasoning: z.boolean(), diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 1123e6bbed..32750221eb 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -280,7 +280,7 @@ export namespace Provider { project, location, }, - async getModel(sdk: any, modelID: string) { + async getModel(sdk, modelID) { const id = String(modelID).trim() return sdk.languageModel(id) }, @@ -299,6 +299,14 @@ export namespace Provider { }, } + export type Model = { + providerID: string + modelID: string + language: LanguageModel + info: ModelsDev.Model + npm: string + } + const state = Instance.state(async () => { using _ = log.time("state") const config = await Config.get() @@ -321,19 +329,8 @@ export namespace Provider { options: Record } } = {} - const models = new Map< - string, - { - providerID: string - modelID: string - info: ModelsDev.Model - language: LanguageModel - npm?: string - } - >() + const models = new Map() const sdk = new Map() - // Maps `${provider}/${key}` to the provider’s actual model ID for custom aliases. - const realIdByKey = new Map() log.info("init") @@ -395,6 +392,7 @@ export namespace Provider { }) const parsedModel: ModelsDev.Model = { id: modelID, + target: model.target ?? existing?.target ?? modelID, name, release_date: model.release_date ?? existing?.release_date, attachment: model.attachment ?? existing?.attachment ?? false, @@ -432,9 +430,6 @@ export namespace Provider { headers: model.headers, provider: model.provider ?? existing?.provider, } - if (model.id && model.id !== modelID) { - realIdByKey.set(`${providerID}/${modelID}`, model.id) - } parsed.models[modelID] = parsedModel } @@ -528,31 +523,22 @@ export namespace Provider { } const configProvider = config.provider?.[providerID] - const filteredModels = Object.fromEntries( - Object.entries(provider.info.models) - // Filter out blacklisted models - .filter( - ([modelID]) => - modelID !== "gpt-5-chat-latest" && !(providerID === "openrouter" && modelID === "openai/gpt-5-chat"), - ) - // Filter out experimental models - .filter( - ([, model]) => - ((!model.experimental && model.status !== "alpha") || Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) && - model.status !== "deprecated", - ) - // Filter by provider's whitelist/blacklist from config - .filter(([modelID]) => { - if (!configProvider) return true - return ( - (!configProvider.blacklist || !configProvider.blacklist.includes(modelID)) && - (!configProvider.whitelist || configProvider.whitelist.includes(modelID)) - ) - }), - ) - - provider.info.models = filteredModels + for (const [modelID, model] of Object.entries(provider.info.models)) { + model.target = model.target ?? model.id ?? modelID + if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat")) + delete provider.info.models[modelID] + if ( + ((model.status === "alpha" || model.experimental) && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) || + model.status === "deprecated" + ) + delete provider.info.models[modelID] + if ( + (configProvider?.blacklist && configProvider.blacklist.includes(modelID)) || + (configProvider?.whitelist && !configProvider.whitelist.includes(modelID)) + ) + delete provider.info.models[modelID] + } if (Object.keys(provider.info.models).length === 0) { delete providers[providerID] @@ -566,7 +552,6 @@ export namespace Provider { models, providers, sdk, - realIdByKey, } }) @@ -574,19 +559,18 @@ export namespace Provider { return state().then((state) => state.providers) } - async function getSDK(provider: ModelsDev.Provider, model: ModelsDev.Model) { - return (async () => { + async function getSDK(npm: string, providerID: string) { + try { using _ = log.time("getSDK", { - providerID: provider.id, + providerID, }) const s = await state() - const pkg = model.provider?.npm ?? provider.npm ?? provider.id - const options = { ...s.providers[provider.id]?.options } - if (pkg.includes("@ai-sdk/openai-compatible") && options["includeUsage"] === undefined) { + const options = { ...s.providers[providerID]?.options } + if (npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) { options["includeUsage"] = true } - const key = Bun.hash.xxHash32(JSON.stringify({ pkg, options })) + const key = Bun.hash.xxHash32(JSON.stringify({ pkg: npm, options })) const existing = s.sdk.get(key) if (existing) return existing @@ -615,12 +599,12 @@ export namespace Provider { } // Special case: google-vertex-anthropic uses a subpath import - const bundledKey = provider.id === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : pkg + const bundledKey = providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : npm const bundledFn = BUNDLED_PROVIDERS[bundledKey] if (bundledFn) { - log.info("using bundled provider", { providerID: provider.id, pkg: bundledKey }) + log.info("using bundled provider", { providerID, pkg: bundledKey }) const loaded = bundledFn({ - name: provider.id, + name: providerID, ...options, }) s.sdk.set(key, loaded) @@ -628,25 +612,25 @@ export namespace Provider { } let installedPath: string - if (!pkg.startsWith("file://")) { - installedPath = await BunProc.install(pkg, "latest") + if (!npm.startsWith("file://")) { + installedPath = await BunProc.install(npm, "latest") } else { - log.info("loading local provider", { pkg }) - installedPath = pkg + log.info("loading local provider", { pkg: npm }) + installedPath = npm } const mod = await import(installedPath) const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!] const loaded = fn({ - name: provider.id, + name: providerID, ...options, }) s.sdk.set(key, loaded) return loaded as SDK - })().catch((e) => { - throw new InitError({ providerID: provider.id }, { cause: e }) - }) + } catch (e) { + throw new InitError({ providerID }, { cause: e }) + } } export async function getProvider(providerID: string) { @@ -679,28 +663,27 @@ export namespace Provider { throw new ModelNotFoundError({ providerID, modelID, suggestions }) } - const sdk = await getSDK(provider.info, info) + const npm = info.provider?.npm ?? provider.info.npm ?? info.id + const sdk = await getSDK(npm, providerID) try { - const keyReal = `${providerID}/${modelID}` - const realID = s.realIdByKey.get(keyReal) ?? info.id const language = provider.getModel - ? await provider.getModel(sdk, realID, provider.options) - : sdk.languageModel(realID) + ? await provider.getModel(sdk, info.target, provider.options) + : sdk.languageModel(info.target) log.info("found", { providerID, modelID }) s.models.set(key, { providerID, modelID, info, language, - npm: info.provider?.npm ?? provider.info.npm, + npm, }) return { modelID, providerID, info, language, - npm: info.provider?.npm ?? provider.info.npm, + npm, } } catch (e) { if (e instanceof NoSuchModelError) diff --git a/packages/opencode/src/provider/transform.ts b/packages/opencode/src/provider/transform.ts index abe269d5d0..d3b30575ad 100644 --- a/packages/opencode/src/provider/transform.ts +++ b/packages/opencode/src/provider/transform.ts @@ -1,10 +1,11 @@ import type { APICallError, ModelMessage } from "ai" import { unique } from "remeda" import type { JSONSchema } from "zod/v4/core" +import type { ModelsDev } from "./models" export namespace ProviderTransform { - function normalizeMessages(msgs: ModelMessage[], providerID: string, modelID: string): ModelMessage[] { - if (modelID.includes("claude")) { + function normalizeMessages(msgs: ModelMessage[], providerID: string, model: ModelsDev.Model): ModelMessage[] { + if (model.target.includes("claude")) { return msgs.map((msg) => { if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) { msg.content = msg.content.map((part) => { @@ -20,7 +21,7 @@ export namespace ProviderTransform { return msg }) } - if (providerID === "mistral" || modelID.toLowerCase().includes("mistral")) { + if (providerID === "mistral" || model.target.toLowerCase().includes("mistral")) { const result: ModelMessage[] = [] for (let i = 0; i < msgs.length; i++) { const msg = msgs[i] @@ -107,30 +108,30 @@ export namespace ProviderTransform { return msgs } - export function message(msgs: ModelMessage[], providerID: string, modelID: string) { - msgs = normalizeMessages(msgs, providerID, modelID) - if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) { + export function message(msgs: ModelMessage[], providerID: string, model: ModelsDev.Model) { + msgs = normalizeMessages(msgs, providerID, model) + if (providerID === "anthropic" || model.target.includes("anthropic") || model.target.includes("claude")) { msgs = applyCaching(msgs, providerID) } return msgs } - export function temperature(_providerID: string, modelID: string) { - if (modelID.toLowerCase().includes("qwen")) return 0.55 - if (modelID.toLowerCase().includes("claude")) return undefined - if (modelID.toLowerCase().includes("gemini-3-pro")) return 1.0 + export function temperature(model: ModelsDev.Model) { + if (model.target.toLowerCase().includes("qwen")) return 0.55 + if (model.target.toLowerCase().includes("claude")) return undefined + if (model.target.toLowerCase().includes("gemini-3-pro")) return 1.0 return 0 } - export function topP(_providerID: string, modelID: string) { - if (modelID.toLowerCase().includes("qwen")) return 1 + export function topP(model: ModelsDev.Model) { + if (model.target.toLowerCase().includes("qwen")) return 1 return undefined } export function options( providerID: string, - modelID: string, + model: ModelsDev.Model, npm: string, sessionID: string, providerOptions?: Record, @@ -148,22 +149,22 @@ export namespace ProviderTransform { result["promptCacheKey"] = sessionID } - if (providerID === "google" || (providerID.startsWith("opencode") && modelID.includes("gemini-3"))) { + if (providerID === "google" || (providerID.startsWith("opencode") && model.target.includes("gemini-3"))) { result["thinkingConfig"] = { includeThoughts: true, } } - if (modelID.includes("gpt-5") && !modelID.includes("gpt-5-chat")) { - if (modelID.includes("codex")) { + if (model.target.includes("gpt-5") && !model.target.includes("gpt-5-chat")) { + if (model.target.includes("codex")) { result["store"] = false } - if (!modelID.includes("codex") && !modelID.includes("gpt-5-pro")) { + if (!model.target.includes("codex") && !model.target.includes("gpt-5-pro")) { result["reasoningEffort"] = "medium" } - if (modelID.endsWith("gpt-5.1") && providerID !== "azure") { + if (model.target.endsWith("gpt-5.1") && providerID !== "azure") { result["textVerbosity"] = "low" } @@ -176,11 +177,11 @@ export namespace ProviderTransform { return result } - export function smallOptions(input: { providerID: string; modelID: string }) { + export function smallOptions(input: { providerID: string; model: ModelsDev.Model }) { const options: Record = {} - if (input.providerID === "openai" || input.modelID.includes("gpt-5")) { - if (input.modelID.includes("5.1")) { + if (input.providerID === "openai" || input.model.target.includes("gpt-5")) { + if (input.model.target.includes("5.1")) { options["reasoningEffort"] = "low" } else { options["reasoningEffort"] = "minimal" @@ -254,7 +255,7 @@ export namespace ProviderTransform { return standardLimit } - export function schema(providerID: string, modelID: string, schema: JSONSchema.BaseSchema) { + export function schema(providerID: string, model: ModelsDev.Model, schema: JSONSchema.BaseSchema) { /* if (["openai", "azure"].includes(providerID)) { if (schema.type === "object" && schema.properties) { @@ -274,7 +275,7 @@ export namespace ProviderTransform { */ // Convert integer enums to string enums for Google/Gemini - if (providerID === "google" || modelID.includes("gemini")) { + if (providerID === "google" || model.target.includes("gemini")) { const sanitizeGemini = (obj: any): any => { if (obj === null || typeof obj !== "object") { return obj diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index fe4ad195aa..4dfd3ac743 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -296,8 +296,8 @@ export namespace Server { }), ), async (c) => { - const { provider, model } = c.req.valid("query") - const tools = await ToolRegistry.tools(provider, model) + const { provider } = c.req.valid("query") + const tools = await ToolRegistry.tools(provider) return c.json( tools.map((t) => ({ id: t.id, diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts index a6b71edcef..d4c9eb99a3 100644 --- a/packages/opencode/src/session/compaction.ts +++ b/packages/opencode/src/session/compaction.ts @@ -1,4 +1,4 @@ -import { streamText, wrapLanguageModel, type ModelMessage } from "ai" +import { wrapLanguageModel, type ModelMessage } from "ai" import { Session } from "." import { Identifier } from "../id/id" import { Instance } from "../project/instance" @@ -130,75 +130,73 @@ export namespace SessionCompaction { model: model.info, abort: input.abort, }) - const result = await processor.process(() => - streamText({ - onError(error) { - log.error("stream error", { - error, - }) - }, - // set to 0, we handle loop - maxRetries: 0, - providerOptions: ProviderTransform.providerOptions( - model.npm, - model.providerID, - pipe( - {}, - mergeDeep(ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", input.sessionID)), - mergeDeep(model.info.options), - ), + const result = await processor.process({ + onError(error) { + log.error("stream error", { + error, + }) + }, + // set to 0, we handle loop + maxRetries: 0, + providerOptions: ProviderTransform.providerOptions( + model.npm, + model.providerID, + pipe( + {}, + mergeDeep(ProviderTransform.options(model.providerID, model.info, model.npm ?? "", input.sessionID)), + mergeDeep(model.info.options), ), - headers: model.info.headers, - abortSignal: input.abort, - tools: model.info.tool_call ? {} : undefined, - messages: [ - ...system.map( - (x): ModelMessage => ({ - role: "system", - content: x, - }), - ), - ...MessageV2.toModelMessage( - input.messages.filter((m) => { - if (m.info.role !== "assistant" || m.info.error === undefined) { - return true - } - if ( - MessageV2.AbortedError.isInstance(m.info.error) && - m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") - ) { - return true - } + ), + headers: model.info.headers, + abortSignal: input.abort, + tools: model.info.tool_call ? {} : undefined, + messages: [ + ...system.map( + (x): ModelMessage => ({ + role: "system", + content: x, + }), + ), + ...MessageV2.toModelMessage( + input.messages.filter((m) => { + if (m.info.role !== "assistant" || m.info.error === undefined) { + return true + } + if ( + MessageV2.AbortedError.isInstance(m.info.error) && + m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") + ) { + return true + } - return false - }), - ), - { - role: "user", - content: [ - { - type: "text", - text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.", - }, - ], - }, - ], - model: wrapLanguageModel({ - model: model.language, - middleware: [ + return false + }), + ), + { + role: "user", + content: [ { - async transformParams(args) { - if (args.type === "stream") { - // @ts-expect-error - args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID) - } - return args.params - }, + type: "text", + text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.", }, ], - }), + }, + ], + model: wrapLanguageModel({ + model: model.language, + middleware: [ + { + async transformParams(args) { + if (args.type === "stream") { + // @ts-expect-error + args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID) + } + return args.params + }, + }, + ], }), - ) + }) if (result === "continue" && input.auto) { const continueMsg = await Session.updateMessage({ id: Identifier.ascending("message"), diff --git a/packages/opencode/src/session/processor.ts b/packages/opencode/src/session/processor.ts index 8655781d5e..5823d6191c 100644 --- a/packages/opencode/src/session/processor.ts +++ b/packages/opencode/src/session/processor.ts @@ -1,6 +1,6 @@ import type { ModelsDev } from "@/provider/models" import { MessageV2 } from "./message-v2" -import { type StreamTextResult, type Tool as AITool, APICallError } from "ai" +import { streamText } from "ai" import { Log } from "@/util/log" import { Identifier } from "@/id/id" import { Session } from "." @@ -19,6 +19,15 @@ export namespace SessionProcessor { export type Info = Awaited> export type Result = Awaited> + export type StreamInput = Parameters[0] + + export type TBD = { + model: { + modelID: string + providerID: string + } + } + export function create(input: { assistantMessage: MessageV2.Assistant sessionID: string @@ -38,13 +47,13 @@ export namespace SessionProcessor { partFromToolCall(toolCallID: string) { return toolcalls[toolCallID] }, - async process(fn: () => StreamTextResult, never>) { + async process(streamInput: StreamInput) { log.info("process") while (true) { try { let currentText: MessageV2.TextPart | undefined let reasoningMap: Record = {} - const stream = fn() + const stream = streamText(streamInput) for await (const value of stream.fullStream) { input.abort.throwIfAborted() diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index ee58bb3380..2433c582b4 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -11,7 +11,6 @@ import { Agent } from "../agent/agent" import { Provider } from "../provider/provider" import { generateText, - streamText, type ModelMessage, type Tool as AITool, tool, @@ -48,6 +47,7 @@ import { fn } from "@/util/fn" import { SessionProcessor } from "./processor" import { TaskTool } from "@/tool/task" import { SessionStatus } from "./status" +import type { ModelsDev } from "@/provider/models" // @ts-ignore globalThis.AI_SDK_LOG_WARNINGS = false @@ -469,14 +469,15 @@ export namespace SessionPrompt { }) const system = await resolveSystemPrompt({ providerID: model.providerID, - modelID: model.info.id, + model: model.info, agent, system: lastUser.system, }) const tools = await resolveTools({ agent, sessionID, - model: lastUser.model, + providerID: model.providerID, + model: model.info, tools: lastUser.tools, processor, }) @@ -492,14 +493,12 @@ export namespace SessionPrompt { }, { temperature: model.info.temperature - ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID)) + ? (agent.temperature ?? ProviderTransform.temperature(model.info)) : undefined, - topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID), + topP: agent.topP ?? ProviderTransform.topP(model.info), options: pipe( {}, - mergeDeep( - ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID, provider?.options), - ), + mergeDeep(ProviderTransform.options(model.providerID, model.info, model.npm, sessionID, provider?.options)), mergeDeep(model.info.options), mergeDeep(agent.options), ), @@ -513,113 +512,111 @@ export namespace SessionPrompt { }) } - const result = await processor.process(() => - streamText({ - onError(error) { - log.error("stream error", { - error, + const result = await processor.process({ + onError(error) { + log.error("stream error", { + error, + }) + }, + async experimental_repairToolCall(input) { + const lower = input.toolCall.toolName.toLowerCase() + if (lower !== input.toolCall.toolName && tools[lower]) { + log.info("repairing tool call", { + tool: input.toolCall.toolName, + repaired: lower, }) - }, - async experimental_repairToolCall(input) { - const lower = input.toolCall.toolName.toLowerCase() - if (lower !== input.toolCall.toolName && tools[lower]) { - log.info("repairing tool call", { - tool: input.toolCall.toolName, - repaired: lower, - }) - return { - ...input.toolCall, - toolName: lower, - } - } return { ...input.toolCall, - input: JSON.stringify({ - tool: input.toolCall.toolName, - error: input.error.message, - }), - toolName: "invalid", + toolName: lower, } - }, - headers: { - ...(model.providerID.startsWith("opencode") - ? { - "x-opencode-project": Instance.project.id, - "x-opencode-session": sessionID, - "x-opencode-request": lastUser.id, - } - : undefined), - ...model.info.headers, - }, - // set to 0, we handle loop - maxRetries: 0, - activeTools: Object.keys(tools).filter((x) => x !== "invalid"), - maxOutputTokens: ProviderTransform.maxOutputTokens( - model.providerID, - params.options, - model.info.limit.output, - OUTPUT_TOKEN_MAX, + } + return { + ...input.toolCall, + input: JSON.stringify({ + tool: input.toolCall.toolName, + error: input.error.message, + }), + toolName: "invalid", + } + }, + headers: { + ...(model.providerID.startsWith("opencode") + ? { + "x-opencode-project": Instance.project.id, + "x-opencode-session": sessionID, + "x-opencode-request": lastUser.id, + } + : undefined), + ...model.info.headers, + }, + // set to 0, we handle loop + maxRetries: 0, + activeTools: Object.keys(tools).filter((x) => x !== "invalid"), + maxOutputTokens: ProviderTransform.maxOutputTokens( + model.providerID, + params.options, + model.info.limit.output, + OUTPUT_TOKEN_MAX, + ), + abortSignal: abort, + providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options), + stopWhen: stepCountIs(1), + temperature: params.temperature, + topP: params.topP, + messages: [ + ...system.map( + (x): ModelMessage => ({ + role: "system", + content: x, + }), ), - abortSignal: abort, - providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options), - stopWhen: stepCountIs(1), - temperature: params.temperature, - topP: params.topP, - messages: [ - ...system.map( - (x): ModelMessage => ({ - role: "system", - content: x, - }), - ), - ...MessageV2.toModelMessage( - msgs.filter((m) => { - if (m.info.role !== "assistant" || m.info.error === undefined) { - return true - } - if ( - MessageV2.AbortedError.isInstance(m.info.error) && - m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") - ) { - return true - } + ...MessageV2.toModelMessage( + msgs.filter((m) => { + if (m.info.role !== "assistant" || m.info.error === undefined) { + return true + } + if ( + MessageV2.AbortedError.isInstance(m.info.error) && + m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") + ) { + return true + } - return false - }), - ), - ], - tools: model.info.tool_call === false ? undefined : tools, - model: wrapLanguageModel({ - model: model.language, - middleware: [ - { - async transformParams(args) { - if (args.type === "stream") { - // @ts-expect-error - args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID) - } - // Transform tool schemas for provider compatibility - if (args.params.tools && Array.isArray(args.params.tools)) { - args.params.tools = args.params.tools.map((tool: any) => { - // Tools at middleware level have inputSchema, not parameters - if (tool.inputSchema && typeof tool.inputSchema === "object") { - // Transform the inputSchema for provider compatibility - return { - ...tool, - inputSchema: ProviderTransform.schema(model.providerID, model.modelID, tool.inputSchema), - } + return false + }), + ), + ], + tools: model.info.tool_call === false ? undefined : tools, + model: wrapLanguageModel({ + model: model.language, + middleware: [ + { + async transformParams(args) { + if (args.type === "stream") { + // @ts-expect-error + args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.info) + } + // Transform tool schemas for provider compatibility + if (args.params.tools && Array.isArray(args.params.tools)) { + args.params.tools = args.params.tools.map((tool: any) => { + // Tools at middleware level have inputSchema, not parameters + if (tool.inputSchema && typeof tool.inputSchema === "object") { + // Transform the inputSchema for provider compatibility + return { + ...tool, + inputSchema: ProviderTransform.schema(model.providerID, model.info, tool.inputSchema), } - // If no inputSchema, return tool unchanged - return tool - }) - } - return args.params - }, + } + // If no inputSchema, return tool unchanged + return tool + }) + } + return args.params }, - ], - }), + }, + ], }), - ) + }) if (result === "stop") break continue } @@ -646,14 +643,14 @@ export namespace SessionPrompt { system?: string agent: Agent.Info providerID: string - modelID: string + model: ModelsDev.Model }) { let system = SystemPrompt.header(input.providerID) system.push( ...(() => { if (input.system) return [input.system] if (input.agent.prompt) return [input.agent.prompt] - return SystemPrompt.provider(input.modelID) + return SystemPrompt.provider(input.model) })(), ) system.push(...(await SystemPrompt.environment())) @@ -666,10 +663,8 @@ export namespace SessionPrompt { async function resolveTools(input: { agent: Agent.Info - model: { - providerID: string - modelID: string - } + providerID: string + model: ModelsDev.Model sessionID: string tools?: Record processor: SessionProcessor.Info @@ -677,16 +672,12 @@ export namespace SessionPrompt { const tools: Record = {} const enabledTools = pipe( input.agent.tools, - mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)), + mergeDeep(await ToolRegistry.enabled(input.agent)), mergeDeep(input.tools ?? {}), ) - for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) { + for (const item of await ToolRegistry.tools(input.providerID)) { if (Wildcard.all(item.id, enabledTools) === false) continue - const schema = ProviderTransform.schema( - input.model.providerID, - input.model.modelID, - z.toJSONSchema(item.parameters), - ) + const schema = ProviderTransform.schema(input.providerID, input.model, z.toJSONSchema(item.parameters)) tools[item.id] = tool({ id: item.id as any, description: item.description, @@ -1441,15 +1432,9 @@ export namespace SessionPrompt { const options = pipe( {}, mergeDeep( - ProviderTransform.options( - small.providerID, - small.modelID, - small.npm ?? "", - input.session.id, - provider?.options, - ), + ProviderTransform.options(small.providerID, small.info, small.npm ?? "", input.session.id, provider?.options), ), - mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })), + mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, model: small.info })), mergeDeep(small.info.options), ) await generateText({ diff --git a/packages/opencode/src/session/summary.ts b/packages/opencode/src/session/summary.ts index d9247f182d..9f56b084e0 100644 --- a/packages/opencode/src/session/summary.ts +++ b/packages/opencode/src/session/summary.ts @@ -79,8 +79,8 @@ export namespace SessionSummary { const options = pipe( {}, - mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", assistantMsg.sessionID)), - mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })), + mergeDeep(ProviderTransform.options(small.providerID, small.info, small.npm ?? "", assistantMsg.sessionID)), + mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, model: small.info })), mergeDeep(small.info.options), ) diff --git a/packages/opencode/src/session/system.ts b/packages/opencode/src/session/system.ts index 399cad8cde..42b398948d 100644 --- a/packages/opencode/src/session/system.ts +++ b/packages/opencode/src/session/system.ts @@ -17,6 +17,7 @@ import PROMPT_COMPACTION from "./prompt/compaction.txt" import PROMPT_SUMMARIZE from "./prompt/summarize.txt" import PROMPT_TITLE from "./prompt/title.txt" import PROMPT_CODEX from "./prompt/codex.txt" +import type { ModelsDev } from "@/provider/models" export namespace SystemPrompt { export function header(providerID: string) { @@ -24,12 +25,13 @@ export namespace SystemPrompt { return [] } - export function provider(modelID: string) { - if (modelID.includes("gpt-5")) return [PROMPT_CODEX] - if (modelID.includes("gpt-") || modelID.includes("o1") || modelID.includes("o3")) return [PROMPT_BEAST] - if (modelID.includes("gemini-")) return [PROMPT_GEMINI] - if (modelID.includes("claude")) return [PROMPT_ANTHROPIC] - if (modelID.includes("polaris-alpha")) return [PROMPT_POLARIS] + export function provider(model: ModelsDev.Model) { + if (model.target.includes("gpt-5")) return [PROMPT_CODEX] + if (model.target.includes("gpt-") || model.target.includes("o1") || model.target.includes("o3")) + return [PROMPT_BEAST] + if (model.target.includes("gemini-")) return [PROMPT_GEMINI] + if (model.target.includes("claude")) return [PROMPT_ANTHROPIC] + if (model.target.includes("polaris-alpha")) return [PROMPT_POLARIS] return [PROMPT_ANTHROPIC_WITHOUT_TODO] } diff --git a/packages/opencode/src/tool/batch.ts b/packages/opencode/src/tool/batch.ts index 7d6449e7dc..cc61b090aa 100644 --- a/packages/opencode/src/tool/batch.ts +++ b/packages/opencode/src/tool/batch.ts @@ -37,7 +37,7 @@ export const BatchTool = Tool.define("batch", async () => { const discardedCalls = params.tool_calls.slice(10) const { ToolRegistry } = await import("./registry") - const availableTools = await ToolRegistry.tools("", "") + const availableTools = await ToolRegistry.tools("") const toolMap = new Map(availableTools.map((t) => [t.id, t])) const executeCall = async (call: (typeof toolCalls)[0]) => { diff --git a/packages/opencode/src/tool/registry.ts b/packages/opencode/src/tool/registry.ts index 26b6ea9fcf..33a54675ff 100644 --- a/packages/opencode/src/tool/registry.ts +++ b/packages/opencode/src/tool/registry.ts @@ -108,7 +108,7 @@ export namespace ToolRegistry { return all().then((x) => x.map((t) => t.id)) } - export async function tools(providerID: string, _modelID: string) { + export async function tools(providerID: string) { const tools = await all() const result = await Promise.all( tools @@ -124,11 +124,7 @@ export namespace ToolRegistry { return result } - export async function enabled( - _providerID: string, - _modelID: string, - agent: Agent.Info, - ): Promise> { + export async function enabled(agent: Agent.Info): Promise> { const result: Record = {} if (agent.permission.edit === "deny") { diff --git a/packages/sdk/js/src/gen/types.gen.ts b/packages/sdk/js/src/gen/types.gen.ts index 80348fb9ad..fcf04444ed 100644 --- a/packages/sdk/js/src/gen/types.gen.ts +++ b/packages/sdk/js/src/gen/types.gen.ts @@ -1110,6 +1110,7 @@ export type Config = { [key: string]: { id?: string name?: string + target?: string release_date?: string attachment?: boolean reasoning?: boolean @@ -1355,6 +1356,7 @@ export type Command = { export type Model = { id: string name: string + target: string release_date: string attachment: boolean reasoning: boolean