From 48cf07d32aac4b237c684b8c8765dc7d7ce651e1 Mon Sep 17 00:00:00 2001 From: Dax Raad Date: Tue, 2 Dec 2025 00:18:26 -0500 Subject: [PATCH] core: refactor model ID system to use target field for provider calls Changed model identification from using model.id to model.target when calling providers, allowing users to specify alternate model IDs while maintaining internal references. This enables more flexible provider configurations and better model mapping. --- packages/opencode/src/config/config.ts | 10 +- packages/opencode/src/provider/models.ts | 1 + packages/opencode/src/provider/provider.ts | 115 ++++------ packages/opencode/src/provider/transform.ts | 47 ++-- packages/opencode/src/server/server.ts | 4 +- packages/opencode/src/session/compaction.ts | 126 +++++------ packages/opencode/src/session/processor.ts | 15 +- packages/opencode/src/session/prompt.ts | 239 +++++++++----------- packages/opencode/src/session/summary.ts | 4 +- packages/opencode/src/session/system.ts | 14 +- packages/opencode/src/tool/batch.ts | 2 +- packages/opencode/src/tool/registry.ts | 8 +- packages/sdk/js/src/gen/types.gen.ts | 2 + 13 files changed, 286 insertions(+), 301 deletions(-) diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index 2bdbbca5b0..6ca8fe55be 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -542,7 +542,15 @@ export namespace Config { .extend({ whitelist: z.array(z.string()).optional(), blacklist: z.array(z.string()).optional(), - models: z.record(z.string(), ModelsDev.Model.partial()).optional(), + models: z + .record( + z.string(), + ModelsDev.Model.partial().refine( + (input) => input.id === undefined, + "The model.id field can no longer be specified. Use model.target to specify an alternate model id to use when calling the provider.", + ), + ) + .optional(), options: z .object({ apiKey: z.string().optional(), diff --git a/packages/opencode/src/provider/models.ts b/packages/opencode/src/provider/models.ts index 676837e152..f8ff2e86a4 100644 --- a/packages/opencode/src/provider/models.ts +++ b/packages/opencode/src/provider/models.ts @@ -13,6 +13,7 @@ export namespace ModelsDev { .object({ id: z.string(), name: z.string(), + target: z.string(), release_date: z.string(), attachment: z.boolean(), reasoning: z.boolean(), diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 1123e6bbed..32750221eb 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -280,7 +280,7 @@ export namespace Provider { project, location, }, - async getModel(sdk: any, modelID: string) { + async getModel(sdk, modelID) { const id = String(modelID).trim() return sdk.languageModel(id) }, @@ -299,6 +299,14 @@ export namespace Provider { }, } + export type Model = { + providerID: string + modelID: string + language: LanguageModel + info: ModelsDev.Model + npm: string + } + const state = Instance.state(async () => { using _ = log.time("state") const config = await Config.get() @@ -321,19 +329,8 @@ export namespace Provider { options: Record } } = {} - const models = new Map< - string, - { - providerID: string - modelID: string - info: ModelsDev.Model - language: LanguageModel - npm?: string - } - >() + const models = new Map() const sdk = new Map() - // Maps `${provider}/${key}` to the provider’s actual model ID for custom aliases. - const realIdByKey = new Map() log.info("init") @@ -395,6 +392,7 @@ export namespace Provider { }) const parsedModel: ModelsDev.Model = { id: modelID, + target: model.target ?? existing?.target ?? modelID, name, release_date: model.release_date ?? existing?.release_date, attachment: model.attachment ?? existing?.attachment ?? false, @@ -432,9 +430,6 @@ export namespace Provider { headers: model.headers, provider: model.provider ?? existing?.provider, } - if (model.id && model.id !== modelID) { - realIdByKey.set(`${providerID}/${modelID}`, model.id) - } parsed.models[modelID] = parsedModel } @@ -528,31 +523,22 @@ export namespace Provider { } const configProvider = config.provider?.[providerID] - const filteredModels = Object.fromEntries( - Object.entries(provider.info.models) - // Filter out blacklisted models - .filter( - ([modelID]) => - modelID !== "gpt-5-chat-latest" && !(providerID === "openrouter" && modelID === "openai/gpt-5-chat"), - ) - // Filter out experimental models - .filter( - ([, model]) => - ((!model.experimental && model.status !== "alpha") || Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) && - model.status !== "deprecated", - ) - // Filter by provider's whitelist/blacklist from config - .filter(([modelID]) => { - if (!configProvider) return true - return ( - (!configProvider.blacklist || !configProvider.blacklist.includes(modelID)) && - (!configProvider.whitelist || configProvider.whitelist.includes(modelID)) - ) - }), - ) - - provider.info.models = filteredModels + for (const [modelID, model] of Object.entries(provider.info.models)) { + model.target = model.target ?? model.id ?? modelID + if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat")) + delete provider.info.models[modelID] + if ( + ((model.status === "alpha" || model.experimental) && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) || + model.status === "deprecated" + ) + delete provider.info.models[modelID] + if ( + (configProvider?.blacklist && configProvider.blacklist.includes(modelID)) || + (configProvider?.whitelist && !configProvider.whitelist.includes(modelID)) + ) + delete provider.info.models[modelID] + } if (Object.keys(provider.info.models).length === 0) { delete providers[providerID] @@ -566,7 +552,6 @@ export namespace Provider { models, providers, sdk, - realIdByKey, } }) @@ -574,19 +559,18 @@ export namespace Provider { return state().then((state) => state.providers) } - async function getSDK(provider: ModelsDev.Provider, model: ModelsDev.Model) { - return (async () => { + async function getSDK(npm: string, providerID: string) { + try { using _ = log.time("getSDK", { - providerID: provider.id, + providerID, }) const s = await state() - const pkg = model.provider?.npm ?? provider.npm ?? provider.id - const options = { ...s.providers[provider.id]?.options } - if (pkg.includes("@ai-sdk/openai-compatible") && options["includeUsage"] === undefined) { + const options = { ...s.providers[providerID]?.options } + if (npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) { options["includeUsage"] = true } - const key = Bun.hash.xxHash32(JSON.stringify({ pkg, options })) + const key = Bun.hash.xxHash32(JSON.stringify({ pkg: npm, options })) const existing = s.sdk.get(key) if (existing) return existing @@ -615,12 +599,12 @@ export namespace Provider { } // Special case: google-vertex-anthropic uses a subpath import - const bundledKey = provider.id === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : pkg + const bundledKey = providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : npm const bundledFn = BUNDLED_PROVIDERS[bundledKey] if (bundledFn) { - log.info("using bundled provider", { providerID: provider.id, pkg: bundledKey }) + log.info("using bundled provider", { providerID, pkg: bundledKey }) const loaded = bundledFn({ - name: provider.id, + name: providerID, ...options, }) s.sdk.set(key, loaded) @@ -628,25 +612,25 @@ export namespace Provider { } let installedPath: string - if (!pkg.startsWith("file://")) { - installedPath = await BunProc.install(pkg, "latest") + if (!npm.startsWith("file://")) { + installedPath = await BunProc.install(npm, "latest") } else { - log.info("loading local provider", { pkg }) - installedPath = pkg + log.info("loading local provider", { pkg: npm }) + installedPath = npm } const mod = await import(installedPath) const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!] const loaded = fn({ - name: provider.id, + name: providerID, ...options, }) s.sdk.set(key, loaded) return loaded as SDK - })().catch((e) => { - throw new InitError({ providerID: provider.id }, { cause: e }) - }) + } catch (e) { + throw new InitError({ providerID }, { cause: e }) + } } export async function getProvider(providerID: string) { @@ -679,28 +663,27 @@ export namespace Provider { throw new ModelNotFoundError({ providerID, modelID, suggestions }) } - const sdk = await getSDK(provider.info, info) + const npm = info.provider?.npm ?? provider.info.npm ?? info.id + const sdk = await getSDK(npm, providerID) try { - const keyReal = `${providerID}/${modelID}` - const realID = s.realIdByKey.get(keyReal) ?? info.id const language = provider.getModel - ? await provider.getModel(sdk, realID, provider.options) - : sdk.languageModel(realID) + ? await provider.getModel(sdk, info.target, provider.options) + : sdk.languageModel(info.target) log.info("found", { providerID, modelID }) s.models.set(key, { providerID, modelID, info, language, - npm: info.provider?.npm ?? provider.info.npm, + npm, }) return { modelID, providerID, info, language, - npm: info.provider?.npm ?? provider.info.npm, + npm, } } catch (e) { if (e instanceof NoSuchModelError) diff --git a/packages/opencode/src/provider/transform.ts b/packages/opencode/src/provider/transform.ts index abe269d5d0..d3b30575ad 100644 --- a/packages/opencode/src/provider/transform.ts +++ b/packages/opencode/src/provider/transform.ts @@ -1,10 +1,11 @@ import type { APICallError, ModelMessage } from "ai" import { unique } from "remeda" import type { JSONSchema } from "zod/v4/core" +import type { ModelsDev } from "./models" export namespace ProviderTransform { - function normalizeMessages(msgs: ModelMessage[], providerID: string, modelID: string): ModelMessage[] { - if (modelID.includes("claude")) { + function normalizeMessages(msgs: ModelMessage[], providerID: string, model: ModelsDev.Model): ModelMessage[] { + if (model.target.includes("claude")) { return msgs.map((msg) => { if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) { msg.content = msg.content.map((part) => { @@ -20,7 +21,7 @@ export namespace ProviderTransform { return msg }) } - if (providerID === "mistral" || modelID.toLowerCase().includes("mistral")) { + if (providerID === "mistral" || model.target.toLowerCase().includes("mistral")) { const result: ModelMessage[] = [] for (let i = 0; i < msgs.length; i++) { const msg = msgs[i] @@ -107,30 +108,30 @@ export namespace ProviderTransform { return msgs } - export function message(msgs: ModelMessage[], providerID: string, modelID: string) { - msgs = normalizeMessages(msgs, providerID, modelID) - if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) { + export function message(msgs: ModelMessage[], providerID: string, model: ModelsDev.Model) { + msgs = normalizeMessages(msgs, providerID, model) + if (providerID === "anthropic" || model.target.includes("anthropic") || model.target.includes("claude")) { msgs = applyCaching(msgs, providerID) } return msgs } - export function temperature(_providerID: string, modelID: string) { - if (modelID.toLowerCase().includes("qwen")) return 0.55 - if (modelID.toLowerCase().includes("claude")) return undefined - if (modelID.toLowerCase().includes("gemini-3-pro")) return 1.0 + export function temperature(model: ModelsDev.Model) { + if (model.target.toLowerCase().includes("qwen")) return 0.55 + if (model.target.toLowerCase().includes("claude")) return undefined + if (model.target.toLowerCase().includes("gemini-3-pro")) return 1.0 return 0 } - export function topP(_providerID: string, modelID: string) { - if (modelID.toLowerCase().includes("qwen")) return 1 + export function topP(model: ModelsDev.Model) { + if (model.target.toLowerCase().includes("qwen")) return 1 return undefined } export function options( providerID: string, - modelID: string, + model: ModelsDev.Model, npm: string, sessionID: string, providerOptions?: Record, @@ -148,22 +149,22 @@ export namespace ProviderTransform { result["promptCacheKey"] = sessionID } - if (providerID === "google" || (providerID.startsWith("opencode") && modelID.includes("gemini-3"))) { + if (providerID === "google" || (providerID.startsWith("opencode") && model.target.includes("gemini-3"))) { result["thinkingConfig"] = { includeThoughts: true, } } - if (modelID.includes("gpt-5") && !modelID.includes("gpt-5-chat")) { - if (modelID.includes("codex")) { + if (model.target.includes("gpt-5") && !model.target.includes("gpt-5-chat")) { + if (model.target.includes("codex")) { result["store"] = false } - if (!modelID.includes("codex") && !modelID.includes("gpt-5-pro")) { + if (!model.target.includes("codex") && !model.target.includes("gpt-5-pro")) { result["reasoningEffort"] = "medium" } - if (modelID.endsWith("gpt-5.1") && providerID !== "azure") { + if (model.target.endsWith("gpt-5.1") && providerID !== "azure") { result["textVerbosity"] = "low" } @@ -176,11 +177,11 @@ export namespace ProviderTransform { return result } - export function smallOptions(input: { providerID: string; modelID: string }) { + export function smallOptions(input: { providerID: string; model: ModelsDev.Model }) { const options: Record = {} - if (input.providerID === "openai" || input.modelID.includes("gpt-5")) { - if (input.modelID.includes("5.1")) { + if (input.providerID === "openai" || input.model.target.includes("gpt-5")) { + if (input.model.target.includes("5.1")) { options["reasoningEffort"] = "low" } else { options["reasoningEffort"] = "minimal" @@ -254,7 +255,7 @@ export namespace ProviderTransform { return standardLimit } - export function schema(providerID: string, modelID: string, schema: JSONSchema.BaseSchema) { + export function schema(providerID: string, model: ModelsDev.Model, schema: JSONSchema.BaseSchema) { /* if (["openai", "azure"].includes(providerID)) { if (schema.type === "object" && schema.properties) { @@ -274,7 +275,7 @@ export namespace ProviderTransform { */ // Convert integer enums to string enums for Google/Gemini - if (providerID === "google" || modelID.includes("gemini")) { + if (providerID === "google" || model.target.includes("gemini")) { const sanitizeGemini = (obj: any): any => { if (obj === null || typeof obj !== "object") { return obj diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index fe4ad195aa..4dfd3ac743 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -296,8 +296,8 @@ export namespace Server { }), ), async (c) => { - const { provider, model } = c.req.valid("query") - const tools = await ToolRegistry.tools(provider, model) + const { provider } = c.req.valid("query") + const tools = await ToolRegistry.tools(provider) return c.json( tools.map((t) => ({ id: t.id, diff --git a/packages/opencode/src/session/compaction.ts b/packages/opencode/src/session/compaction.ts index a6b71edcef..d4c9eb99a3 100644 --- a/packages/opencode/src/session/compaction.ts +++ b/packages/opencode/src/session/compaction.ts @@ -1,4 +1,4 @@ -import { streamText, wrapLanguageModel, type ModelMessage } from "ai" +import { wrapLanguageModel, type ModelMessage } from "ai" import { Session } from "." import { Identifier } from "../id/id" import { Instance } from "../project/instance" @@ -130,75 +130,73 @@ export namespace SessionCompaction { model: model.info, abort: input.abort, }) - const result = await processor.process(() => - streamText({ - onError(error) { - log.error("stream error", { - error, - }) - }, - // set to 0, we handle loop - maxRetries: 0, - providerOptions: ProviderTransform.providerOptions( - model.npm, - model.providerID, - pipe( - {}, - mergeDeep(ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", input.sessionID)), - mergeDeep(model.info.options), - ), + const result = await processor.process({ + onError(error) { + log.error("stream error", { + error, + }) + }, + // set to 0, we handle loop + maxRetries: 0, + providerOptions: ProviderTransform.providerOptions( + model.npm, + model.providerID, + pipe( + {}, + mergeDeep(ProviderTransform.options(model.providerID, model.info, model.npm ?? "", input.sessionID)), + mergeDeep(model.info.options), ), - headers: model.info.headers, - abortSignal: input.abort, - tools: model.info.tool_call ? {} : undefined, - messages: [ - ...system.map( - (x): ModelMessage => ({ - role: "system", - content: x, - }), - ), - ...MessageV2.toModelMessage( - input.messages.filter((m) => { - if (m.info.role !== "assistant" || m.info.error === undefined) { - return true - } - if ( - MessageV2.AbortedError.isInstance(m.info.error) && - m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") - ) { - return true - } + ), + headers: model.info.headers, + abortSignal: input.abort, + tools: model.info.tool_call ? {} : undefined, + messages: [ + ...system.map( + (x): ModelMessage => ({ + role: "system", + content: x, + }), + ), + ...MessageV2.toModelMessage( + input.messages.filter((m) => { + if (m.info.role !== "assistant" || m.info.error === undefined) { + return true + } + if ( + MessageV2.AbortedError.isInstance(m.info.error) && + m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") + ) { + return true + } - return false - }), - ), - { - role: "user", - content: [ - { - type: "text", - text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.", - }, - ], - }, - ], - model: wrapLanguageModel({ - model: model.language, - middleware: [ + return false + }), + ), + { + role: "user", + content: [ { - async transformParams(args) { - if (args.type === "stream") { - // @ts-expect-error - args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID) - } - return args.params - }, + type: "text", + text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.", }, ], - }), + }, + ], + model: wrapLanguageModel({ + model: model.language, + middleware: [ + { + async transformParams(args) { + if (args.type === "stream") { + // @ts-expect-error + args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID) + } + return args.params + }, + }, + ], }), - ) + }) if (result === "continue" && input.auto) { const continueMsg = await Session.updateMessage({ id: Identifier.ascending("message"), diff --git a/packages/opencode/src/session/processor.ts b/packages/opencode/src/session/processor.ts index 8655781d5e..5823d6191c 100644 --- a/packages/opencode/src/session/processor.ts +++ b/packages/opencode/src/session/processor.ts @@ -1,6 +1,6 @@ import type { ModelsDev } from "@/provider/models" import { MessageV2 } from "./message-v2" -import { type StreamTextResult, type Tool as AITool, APICallError } from "ai" +import { streamText } from "ai" import { Log } from "@/util/log" import { Identifier } from "@/id/id" import { Session } from "." @@ -19,6 +19,15 @@ export namespace SessionProcessor { export type Info = Awaited> export type Result = Awaited> + export type StreamInput = Parameters[0] + + export type TBD = { + model: { + modelID: string + providerID: string + } + } + export function create(input: { assistantMessage: MessageV2.Assistant sessionID: string @@ -38,13 +47,13 @@ export namespace SessionProcessor { partFromToolCall(toolCallID: string) { return toolcalls[toolCallID] }, - async process(fn: () => StreamTextResult, never>) { + async process(streamInput: StreamInput) { log.info("process") while (true) { try { let currentText: MessageV2.TextPart | undefined let reasoningMap: Record = {} - const stream = fn() + const stream = streamText(streamInput) for await (const value of stream.fullStream) { input.abort.throwIfAborted() diff --git a/packages/opencode/src/session/prompt.ts b/packages/opencode/src/session/prompt.ts index ee58bb3380..2433c582b4 100644 --- a/packages/opencode/src/session/prompt.ts +++ b/packages/opencode/src/session/prompt.ts @@ -11,7 +11,6 @@ import { Agent } from "../agent/agent" import { Provider } from "../provider/provider" import { generateText, - streamText, type ModelMessage, type Tool as AITool, tool, @@ -48,6 +47,7 @@ import { fn } from "@/util/fn" import { SessionProcessor } from "./processor" import { TaskTool } from "@/tool/task" import { SessionStatus } from "./status" +import type { ModelsDev } from "@/provider/models" // @ts-ignore globalThis.AI_SDK_LOG_WARNINGS = false @@ -469,14 +469,15 @@ export namespace SessionPrompt { }) const system = await resolveSystemPrompt({ providerID: model.providerID, - modelID: model.info.id, + model: model.info, agent, system: lastUser.system, }) const tools = await resolveTools({ agent, sessionID, - model: lastUser.model, + providerID: model.providerID, + model: model.info, tools: lastUser.tools, processor, }) @@ -492,14 +493,12 @@ export namespace SessionPrompt { }, { temperature: model.info.temperature - ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID)) + ? (agent.temperature ?? ProviderTransform.temperature(model.info)) : undefined, - topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID), + topP: agent.topP ?? ProviderTransform.topP(model.info), options: pipe( {}, - mergeDeep( - ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID, provider?.options), - ), + mergeDeep(ProviderTransform.options(model.providerID, model.info, model.npm, sessionID, provider?.options)), mergeDeep(model.info.options), mergeDeep(agent.options), ), @@ -513,113 +512,111 @@ export namespace SessionPrompt { }) } - const result = await processor.process(() => - streamText({ - onError(error) { - log.error("stream error", { - error, + const result = await processor.process({ + onError(error) { + log.error("stream error", { + error, + }) + }, + async experimental_repairToolCall(input) { + const lower = input.toolCall.toolName.toLowerCase() + if (lower !== input.toolCall.toolName && tools[lower]) { + log.info("repairing tool call", { + tool: input.toolCall.toolName, + repaired: lower, }) - }, - async experimental_repairToolCall(input) { - const lower = input.toolCall.toolName.toLowerCase() - if (lower !== input.toolCall.toolName && tools[lower]) { - log.info("repairing tool call", { - tool: input.toolCall.toolName, - repaired: lower, - }) - return { - ...input.toolCall, - toolName: lower, - } - } return { ...input.toolCall, - input: JSON.stringify({ - tool: input.toolCall.toolName, - error: input.error.message, - }), - toolName: "invalid", + toolName: lower, } - }, - headers: { - ...(model.providerID.startsWith("opencode") - ? { - "x-opencode-project": Instance.project.id, - "x-opencode-session": sessionID, - "x-opencode-request": lastUser.id, - } - : undefined), - ...model.info.headers, - }, - // set to 0, we handle loop - maxRetries: 0, - activeTools: Object.keys(tools).filter((x) => x !== "invalid"), - maxOutputTokens: ProviderTransform.maxOutputTokens( - model.providerID, - params.options, - model.info.limit.output, - OUTPUT_TOKEN_MAX, + } + return { + ...input.toolCall, + input: JSON.stringify({ + tool: input.toolCall.toolName, + error: input.error.message, + }), + toolName: "invalid", + } + }, + headers: { + ...(model.providerID.startsWith("opencode") + ? { + "x-opencode-project": Instance.project.id, + "x-opencode-session": sessionID, + "x-opencode-request": lastUser.id, + } + : undefined), + ...model.info.headers, + }, + // set to 0, we handle loop + maxRetries: 0, + activeTools: Object.keys(tools).filter((x) => x !== "invalid"), + maxOutputTokens: ProviderTransform.maxOutputTokens( + model.providerID, + params.options, + model.info.limit.output, + OUTPUT_TOKEN_MAX, + ), + abortSignal: abort, + providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options), + stopWhen: stepCountIs(1), + temperature: params.temperature, + topP: params.topP, + messages: [ + ...system.map( + (x): ModelMessage => ({ + role: "system", + content: x, + }), ), - abortSignal: abort, - providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options), - stopWhen: stepCountIs(1), - temperature: params.temperature, - topP: params.topP, - messages: [ - ...system.map( - (x): ModelMessage => ({ - role: "system", - content: x, - }), - ), - ...MessageV2.toModelMessage( - msgs.filter((m) => { - if (m.info.role !== "assistant" || m.info.error === undefined) { - return true - } - if ( - MessageV2.AbortedError.isInstance(m.info.error) && - m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") - ) { - return true - } + ...MessageV2.toModelMessage( + msgs.filter((m) => { + if (m.info.role !== "assistant" || m.info.error === undefined) { + return true + } + if ( + MessageV2.AbortedError.isInstance(m.info.error) && + m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning") + ) { + return true + } - return false - }), - ), - ], - tools: model.info.tool_call === false ? undefined : tools, - model: wrapLanguageModel({ - model: model.language, - middleware: [ - { - async transformParams(args) { - if (args.type === "stream") { - // @ts-expect-error - args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID) - } - // Transform tool schemas for provider compatibility - if (args.params.tools && Array.isArray(args.params.tools)) { - args.params.tools = args.params.tools.map((tool: any) => { - // Tools at middleware level have inputSchema, not parameters - if (tool.inputSchema && typeof tool.inputSchema === "object") { - // Transform the inputSchema for provider compatibility - return { - ...tool, - inputSchema: ProviderTransform.schema(model.providerID, model.modelID, tool.inputSchema), - } + return false + }), + ), + ], + tools: model.info.tool_call === false ? undefined : tools, + model: wrapLanguageModel({ + model: model.language, + middleware: [ + { + async transformParams(args) { + if (args.type === "stream") { + // @ts-expect-error + args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.info) + } + // Transform tool schemas for provider compatibility + if (args.params.tools && Array.isArray(args.params.tools)) { + args.params.tools = args.params.tools.map((tool: any) => { + // Tools at middleware level have inputSchema, not parameters + if (tool.inputSchema && typeof tool.inputSchema === "object") { + // Transform the inputSchema for provider compatibility + return { + ...tool, + inputSchema: ProviderTransform.schema(model.providerID, model.info, tool.inputSchema), } - // If no inputSchema, return tool unchanged - return tool - }) - } - return args.params - }, + } + // If no inputSchema, return tool unchanged + return tool + }) + } + return args.params }, - ], - }), + }, + ], }), - ) + }) if (result === "stop") break continue } @@ -646,14 +643,14 @@ export namespace SessionPrompt { system?: string agent: Agent.Info providerID: string - modelID: string + model: ModelsDev.Model }) { let system = SystemPrompt.header(input.providerID) system.push( ...(() => { if (input.system) return [input.system] if (input.agent.prompt) return [input.agent.prompt] - return SystemPrompt.provider(input.modelID) + return SystemPrompt.provider(input.model) })(), ) system.push(...(await SystemPrompt.environment())) @@ -666,10 +663,8 @@ export namespace SessionPrompt { async function resolveTools(input: { agent: Agent.Info - model: { - providerID: string - modelID: string - } + providerID: string + model: ModelsDev.Model sessionID: string tools?: Record processor: SessionProcessor.Info @@ -677,16 +672,12 @@ export namespace SessionPrompt { const tools: Record = {} const enabledTools = pipe( input.agent.tools, - mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)), + mergeDeep(await ToolRegistry.enabled(input.agent)), mergeDeep(input.tools ?? {}), ) - for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) { + for (const item of await ToolRegistry.tools(input.providerID)) { if (Wildcard.all(item.id, enabledTools) === false) continue - const schema = ProviderTransform.schema( - input.model.providerID, - input.model.modelID, - z.toJSONSchema(item.parameters), - ) + const schema = ProviderTransform.schema(input.providerID, input.model, z.toJSONSchema(item.parameters)) tools[item.id] = tool({ id: item.id as any, description: item.description, @@ -1441,15 +1432,9 @@ export namespace SessionPrompt { const options = pipe( {}, mergeDeep( - ProviderTransform.options( - small.providerID, - small.modelID, - small.npm ?? "", - input.session.id, - provider?.options, - ), + ProviderTransform.options(small.providerID, small.info, small.npm ?? "", input.session.id, provider?.options), ), - mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })), + mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, model: small.info })), mergeDeep(small.info.options), ) await generateText({ diff --git a/packages/opencode/src/session/summary.ts b/packages/opencode/src/session/summary.ts index d9247f182d..9f56b084e0 100644 --- a/packages/opencode/src/session/summary.ts +++ b/packages/opencode/src/session/summary.ts @@ -79,8 +79,8 @@ export namespace SessionSummary { const options = pipe( {}, - mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", assistantMsg.sessionID)), - mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })), + mergeDeep(ProviderTransform.options(small.providerID, small.info, small.npm ?? "", assistantMsg.sessionID)), + mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, model: small.info })), mergeDeep(small.info.options), ) diff --git a/packages/opencode/src/session/system.ts b/packages/opencode/src/session/system.ts index 399cad8cde..42b398948d 100644 --- a/packages/opencode/src/session/system.ts +++ b/packages/opencode/src/session/system.ts @@ -17,6 +17,7 @@ import PROMPT_COMPACTION from "./prompt/compaction.txt" import PROMPT_SUMMARIZE from "./prompt/summarize.txt" import PROMPT_TITLE from "./prompt/title.txt" import PROMPT_CODEX from "./prompt/codex.txt" +import type { ModelsDev } from "@/provider/models" export namespace SystemPrompt { export function header(providerID: string) { @@ -24,12 +25,13 @@ export namespace SystemPrompt { return [] } - export function provider(modelID: string) { - if (modelID.includes("gpt-5")) return [PROMPT_CODEX] - if (modelID.includes("gpt-") || modelID.includes("o1") || modelID.includes("o3")) return [PROMPT_BEAST] - if (modelID.includes("gemini-")) return [PROMPT_GEMINI] - if (modelID.includes("claude")) return [PROMPT_ANTHROPIC] - if (modelID.includes("polaris-alpha")) return [PROMPT_POLARIS] + export function provider(model: ModelsDev.Model) { + if (model.target.includes("gpt-5")) return [PROMPT_CODEX] + if (model.target.includes("gpt-") || model.target.includes("o1") || model.target.includes("o3")) + return [PROMPT_BEAST] + if (model.target.includes("gemini-")) return [PROMPT_GEMINI] + if (model.target.includes("claude")) return [PROMPT_ANTHROPIC] + if (model.target.includes("polaris-alpha")) return [PROMPT_POLARIS] return [PROMPT_ANTHROPIC_WITHOUT_TODO] } diff --git a/packages/opencode/src/tool/batch.ts b/packages/opencode/src/tool/batch.ts index 7d6449e7dc..cc61b090aa 100644 --- a/packages/opencode/src/tool/batch.ts +++ b/packages/opencode/src/tool/batch.ts @@ -37,7 +37,7 @@ export const BatchTool = Tool.define("batch", async () => { const discardedCalls = params.tool_calls.slice(10) const { ToolRegistry } = await import("./registry") - const availableTools = await ToolRegistry.tools("", "") + const availableTools = await ToolRegistry.tools("") const toolMap = new Map(availableTools.map((t) => [t.id, t])) const executeCall = async (call: (typeof toolCalls)[0]) => { diff --git a/packages/opencode/src/tool/registry.ts b/packages/opencode/src/tool/registry.ts index 26b6ea9fcf..33a54675ff 100644 --- a/packages/opencode/src/tool/registry.ts +++ b/packages/opencode/src/tool/registry.ts @@ -108,7 +108,7 @@ export namespace ToolRegistry { return all().then((x) => x.map((t) => t.id)) } - export async function tools(providerID: string, _modelID: string) { + export async function tools(providerID: string) { const tools = await all() const result = await Promise.all( tools @@ -124,11 +124,7 @@ export namespace ToolRegistry { return result } - export async function enabled( - _providerID: string, - _modelID: string, - agent: Agent.Info, - ): Promise> { + export async function enabled(agent: Agent.Info): Promise> { const result: Record = {} if (agent.permission.edit === "deny") { diff --git a/packages/sdk/js/src/gen/types.gen.ts b/packages/sdk/js/src/gen/types.gen.ts index 80348fb9ad..fcf04444ed 100644 --- a/packages/sdk/js/src/gen/types.gen.ts +++ b/packages/sdk/js/src/gen/types.gen.ts @@ -1110,6 +1110,7 @@ export type Config = { [key: string]: { id?: string name?: string + target?: string release_date?: string attachment?: boolean reasoning?: boolean @@ -1355,6 +1356,7 @@ export type Command = { export type Model = { id: string name: string + target: string release_date: string attachment: boolean reasoning: boolean