core: refactor model ID system to use target field for provider calls

Changed model identification from using model.id to model.target when calling providers, allowing users to specify alternate model IDs while maintaining internal references. This enables more flexible provider configurations and better model mapping.
provider-cleanup
Dax Raad 2025-12-02 00:18:26 -05:00
parent ee4437ff32
commit 48cf07d32a
13 changed files with 286 additions and 301 deletions

View File

@ -542,7 +542,15 @@ export namespace Config {
.extend({
whitelist: z.array(z.string()).optional(),
blacklist: z.array(z.string()).optional(),
models: z.record(z.string(), ModelsDev.Model.partial()).optional(),
// Per-model config overrides. Overriding `id` is rejected by the refine
// below: users must use `target` to point at an alternate provider-side
// model id (see ModelsDev.Model.target).
models: z
.record(
z.string(),
ModelsDev.Model.partial().refine(
(input) => input.id === undefined,
"The model.id field can no longer be specified. Use model.target to specify an alternate model id to use when calling the provider.",
),
)
.optional(),
options: z
.object({
apiKey: z.string().optional(),

View File

@ -13,6 +13,7 @@ export namespace ModelsDev {
.object({
id: z.string(),
name: z.string(),
target: z.string(),
release_date: z.string(),
attachment: z.boolean(),
reasoning: z.boolean(),

View File

@ -280,7 +280,7 @@ export namespace Provider {
project,
location,
},
async getModel(sdk: any, modelID: string) {
async getModel(sdk, modelID) {
const id = String(modelID).trim()
return sdk.languageModel(id)
},
@ -299,6 +299,14 @@ export namespace Provider {
},
}
// A fully-resolved model handle cached in provider state: public
// identifiers plus the instantiated AI SDK language model.
export type Model = {
// Public provider identifier (key in the providers map).
providerID: string
// Public model identifier as users reference it; the id actually sent to
// the provider is info.target, which may differ.
modelID: string
// Live AI SDK LanguageModel created from the provider SDK.
language: LanguageModel
// models.dev metadata for this model, including `target`.
info: ModelsDev.Model
// npm package of the provider SDK used to build `language`.
npm: string
}
const state = Instance.state(async () => {
using _ = log.time("state")
const config = await Config.get()
@ -321,19 +329,8 @@ export namespace Provider {
options: Record<string, any>
}
} = {}
const models = new Map<
string,
{
providerID: string
modelID: string
info: ModelsDev.Model
language: LanguageModel
npm?: string
}
>()
const models = new Map<string, Model>()
const sdk = new Map<number, SDK>()
// Maps `${provider}/${key}` to the providers actual model ID for custom aliases.
const realIdByKey = new Map<string, string>()
log.info("init")
@ -395,6 +392,7 @@ export namespace Provider {
})
const parsedModel: ModelsDev.Model = {
id: modelID,
target: model.target ?? existing?.target ?? modelID,
name,
release_date: model.release_date ?? existing?.release_date,
attachment: model.attachment ?? existing?.attachment ?? false,
@ -432,9 +430,6 @@ export namespace Provider {
headers: model.headers,
provider: model.provider ?? existing?.provider,
}
if (model.id && model.id !== modelID) {
realIdByKey.set(`${providerID}/${modelID}`, model.id)
}
parsed.models[modelID] = parsedModel
}
@ -528,31 +523,22 @@ export namespace Provider {
}
const configProvider = config.provider?.[providerID]
const filteredModels = Object.fromEntries(
Object.entries(provider.info.models)
// Filter out blacklisted models
.filter(
([modelID]) =>
modelID !== "gpt-5-chat-latest" && !(providerID === "openrouter" && modelID === "openai/gpt-5-chat"),
)
// Filter out experimental models
.filter(
([, model]) =>
((!model.experimental && model.status !== "alpha") || Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) &&
model.status !== "deprecated",
)
// Filter by provider's whitelist/blacklist from config
.filter(([modelID]) => {
if (!configProvider) return true
return (
(!configProvider.blacklist || !configProvider.blacklist.includes(modelID)) &&
(!configProvider.whitelist || configProvider.whitelist.includes(modelID))
)
}),
)
provider.info.models = filteredModels
for (const [modelID, model] of Object.entries(provider.info.models)) {
model.target = model.target ?? model.id ?? modelID
if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
delete provider.info.models[modelID]
if (
((model.status === "alpha" || model.experimental) && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) ||
model.status === "deprecated"
)
delete provider.info.models[modelID]
if (
(configProvider?.blacklist && configProvider.blacklist.includes(modelID)) ||
(configProvider?.whitelist && !configProvider.whitelist.includes(modelID))
)
delete provider.info.models[modelID]
}
if (Object.keys(provider.info.models).length === 0) {
delete providers[providerID]
@ -566,7 +552,6 @@ export namespace Provider {
models,
providers,
sdk,
realIdByKey,
}
})
@ -574,19 +559,18 @@ export namespace Provider {
return state().then((state) => state.providers)
}
async function getSDK(provider: ModelsDev.Provider, model: ModelsDev.Model) {
return (async () => {
async function getSDK(npm: string, providerID: string) {
try {
using _ = log.time("getSDK", {
providerID: provider.id,
providerID,
})
const s = await state()
const pkg = model.provider?.npm ?? provider.npm ?? provider.id
const options = { ...s.providers[provider.id]?.options }
if (pkg.includes("@ai-sdk/openai-compatible") && options["includeUsage"] === undefined) {
const options = { ...s.providers[providerID]?.options }
if (npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) {
options["includeUsage"] = true
}
const key = Bun.hash.xxHash32(JSON.stringify({ pkg, options }))
const key = Bun.hash.xxHash32(JSON.stringify({ pkg: npm, options }))
const existing = s.sdk.get(key)
if (existing) return existing
@ -615,12 +599,12 @@ export namespace Provider {
}
// Special case: google-vertex-anthropic uses a subpath import
const bundledKey = provider.id === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : pkg
const bundledKey = providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : npm
const bundledFn = BUNDLED_PROVIDERS[bundledKey]
if (bundledFn) {
log.info("using bundled provider", { providerID: provider.id, pkg: bundledKey })
log.info("using bundled provider", { providerID, pkg: bundledKey })
const loaded = bundledFn({
name: provider.id,
name: providerID,
...options,
})
s.sdk.set(key, loaded)
@ -628,25 +612,25 @@ export namespace Provider {
}
let installedPath: string
if (!pkg.startsWith("file://")) {
installedPath = await BunProc.install(pkg, "latest")
if (!npm.startsWith("file://")) {
installedPath = await BunProc.install(npm, "latest")
} else {
log.info("loading local provider", { pkg })
installedPath = pkg
log.info("loading local provider", { pkg: npm })
installedPath = npm
}
const mod = await import(installedPath)
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
const loaded = fn({
name: provider.id,
name: providerID,
...options,
})
s.sdk.set(key, loaded)
return loaded as SDK
})().catch((e) => {
throw new InitError({ providerID: provider.id }, { cause: e })
})
} catch (e) {
throw new InitError({ providerID }, { cause: e })
}
}
export async function getProvider(providerID: string) {
@ -679,28 +663,27 @@ export namespace Provider {
throw new ModelNotFoundError({ providerID, modelID, suggestions })
}
const sdk = await getSDK(provider.info, info)
const npm = info.provider?.npm ?? provider.info.npm ?? info.id
const sdk = await getSDK(npm, providerID)
try {
const keyReal = `${providerID}/${modelID}`
const realID = s.realIdByKey.get(keyReal) ?? info.id
const language = provider.getModel
? await provider.getModel(sdk, realID, provider.options)
: sdk.languageModel(realID)
? await provider.getModel(sdk, info.target, provider.options)
: sdk.languageModel(info.target)
log.info("found", { providerID, modelID })
s.models.set(key, {
providerID,
modelID,
info,
language,
npm: info.provider?.npm ?? provider.info.npm,
npm,
})
return {
modelID,
providerID,
info,
language,
npm: info.provider?.npm ?? provider.info.npm,
npm,
}
} catch (e) {
if (e instanceof NoSuchModelError)

View File

@ -1,10 +1,11 @@
import type { APICallError, ModelMessage } from "ai"
import { unique } from "remeda"
import type { JSONSchema } from "zod/v4/core"
import type { ModelsDev } from "./models"
export namespace ProviderTransform {
function normalizeMessages(msgs: ModelMessage[], providerID: string, modelID: string): ModelMessage[] {
if (modelID.includes("claude")) {
function normalizeMessages(msgs: ModelMessage[], providerID: string, model: ModelsDev.Model): ModelMessage[] {
if (model.target.includes("claude")) {
return msgs.map((msg) => {
if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
msg.content = msg.content.map((part) => {
@ -20,7 +21,7 @@ export namespace ProviderTransform {
return msg
})
}
if (providerID === "mistral" || modelID.toLowerCase().includes("mistral")) {
if (providerID === "mistral" || model.target.toLowerCase().includes("mistral")) {
const result: ModelMessage[] = []
for (let i = 0; i < msgs.length; i++) {
const msg = msgs[i]
@ -107,30 +108,30 @@ export namespace ProviderTransform {
return msgs
}
export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
msgs = normalizeMessages(msgs, providerID, modelID)
if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
export function message(msgs: ModelMessage[], providerID: string, model: ModelsDev.Model) {
msgs = normalizeMessages(msgs, providerID, model)
if (providerID === "anthropic" || model.target.includes("anthropic") || model.target.includes("claude")) {
msgs = applyCaching(msgs, providerID)
}
return msgs
}
export function temperature(_providerID: string, modelID: string) {
if (modelID.toLowerCase().includes("qwen")) return 0.55
if (modelID.toLowerCase().includes("claude")) return undefined
if (modelID.toLowerCase().includes("gemini-3-pro")) return 1.0
// Pick a default sampling temperature from the provider-facing model id.
// Returns undefined for Claude models (let the provider default apply).
export function temperature(model: ModelsDev.Model) {
const target = model.target.toLowerCase()
if (target.includes("qwen")) return 0.55
if (target.includes("claude")) return undefined
if (target.includes("gemini-3-pro")) return 1.0
return 0
}
export function topP(_providerID: string, modelID: string) {
if (modelID.toLowerCase().includes("qwen")) return 1
// Default nucleus-sampling value: Qwen models get topP=1, everything else
// is left to the provider default.
export function topP(model: ModelsDev.Model) {
return model.target.toLowerCase().includes("qwen") ? 1 : undefined
}
export function options(
providerID: string,
modelID: string,
model: ModelsDev.Model,
npm: string,
sessionID: string,
providerOptions?: Record<string, any>,
@ -148,22 +149,22 @@ export namespace ProviderTransform {
result["promptCacheKey"] = sessionID
}
if (providerID === "google" || (providerID.startsWith("opencode") && modelID.includes("gemini-3"))) {
if (providerID === "google" || (providerID.startsWith("opencode") && model.target.includes("gemini-3"))) {
result["thinkingConfig"] = {
includeThoughts: true,
}
}
if (modelID.includes("gpt-5") && !modelID.includes("gpt-5-chat")) {
if (modelID.includes("codex")) {
if (model.target.includes("gpt-5") && !model.target.includes("gpt-5-chat")) {
if (model.target.includes("codex")) {
result["store"] = false
}
if (!modelID.includes("codex") && !modelID.includes("gpt-5-pro")) {
if (!model.target.includes("codex") && !model.target.includes("gpt-5-pro")) {
result["reasoningEffort"] = "medium"
}
if (modelID.endsWith("gpt-5.1") && providerID !== "azure") {
if (model.target.endsWith("gpt-5.1") && providerID !== "azure") {
result["textVerbosity"] = "low"
}
@ -176,11 +177,11 @@ export namespace ProviderTransform {
return result
}
export function smallOptions(input: { providerID: string; modelID: string }) {
export function smallOptions(input: { providerID: string; model: ModelsDev.Model }) {
const options: Record<string, any> = {}
if (input.providerID === "openai" || input.modelID.includes("gpt-5")) {
if (input.modelID.includes("5.1")) {
if (input.providerID === "openai" || input.model.target.includes("gpt-5")) {
if (input.model.target.includes("5.1")) {
options["reasoningEffort"] = "low"
} else {
options["reasoningEffort"] = "minimal"
@ -254,7 +255,7 @@ export namespace ProviderTransform {
return standardLimit
}
export function schema(providerID: string, modelID: string, schema: JSONSchema.BaseSchema) {
export function schema(providerID: string, model: ModelsDev.Model, schema: JSONSchema.BaseSchema) {
/*
if (["openai", "azure"].includes(providerID)) {
if (schema.type === "object" && schema.properties) {
@ -274,7 +275,7 @@ export namespace ProviderTransform {
*/
// Convert integer enums to string enums for Google/Gemini
if (providerID === "google" || modelID.includes("gemini")) {
if (providerID === "google" || model.target.includes("gemini")) {
const sanitizeGemini = (obj: any): any => {
if (obj === null || typeof obj !== "object") {
return obj

View File

@ -296,8 +296,8 @@ export namespace Server {
}),
),
async (c) => {
const { provider, model } = c.req.valid("query")
const tools = await ToolRegistry.tools(provider, model)
const { provider } = c.req.valid("query")
const tools = await ToolRegistry.tools(provider)
return c.json(
tools.map((t) => ({
id: t.id,

View File

@ -1,4 +1,4 @@
import { streamText, wrapLanguageModel, type ModelMessage } from "ai"
import { wrapLanguageModel, type ModelMessage } from "ai"
import { Session } from "."
import { Identifier } from "../id/id"
import { Instance } from "../project/instance"
@ -130,75 +130,73 @@ export namespace SessionCompaction {
model: model.info,
abort: input.abort,
})
const result = await processor.process(() =>
streamText({
onError(error) {
log.error("stream error", {
error,
})
},
// set to 0, we handle loop
maxRetries: 0,
providerOptions: ProviderTransform.providerOptions(
model.npm,
model.providerID,
pipe(
{},
mergeDeep(ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", input.sessionID)),
mergeDeep(model.info.options),
),
const result = await processor.process({
onError(error) {
log.error("stream error", {
error,
})
},
// set to 0, we handle loop
maxRetries: 0,
providerOptions: ProviderTransform.providerOptions(
model.npm,
model.providerID,
pipe(
{},
mergeDeep(ProviderTransform.options(model.providerID, model.info, model.npm ?? "", input.sessionID)),
mergeDeep(model.info.options),
),
headers: model.info.headers,
abortSignal: input.abort,
tools: model.info.tool_call ? {} : undefined,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
input.messages.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
),
headers: model.info.headers,
abortSignal: input.abort,
tools: model.info.tool_call ? {} : undefined,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
input.messages.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
{
role: "user",
content: [
{
type: "text",
text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.",
},
],
},
],
model: wrapLanguageModel({
model: model.language,
middleware: [
return false
}),
),
{
role: "user",
content: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
}
return args.params
},
type: "text",
text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.",
},
],
}),
},
],
model: wrapLanguageModel({
model: model.language,
middleware: [
{
// Normalize the outgoing prompt for provider quirks just before the
// stream call.
async transformParams(args) {
if (args.type === "stream") {
// BUG FIX: ProviderTransform.message now takes the full model
// info (it reads model.target), not the model id string — the
// @ts-expect-error was masking this missed call-site update.
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.info)
}
return args.params
},
},
],
}),
)
})
if (result === "continue" && input.auto) {
const continueMsg = await Session.updateMessage({
id: Identifier.ascending("message"),

View File

@ -1,6 +1,6 @@
import type { ModelsDev } from "@/provider/models"
import { MessageV2 } from "./message-v2"
import { type StreamTextResult, type Tool as AITool, APICallError } from "ai"
import { streamText } from "ai"
import { Log } from "@/util/log"
import { Identifier } from "@/id/id"
import { Session } from "."
@ -19,6 +19,15 @@ export namespace SessionProcessor {
export type Info = Awaited<ReturnType<typeof create>>
export type Result = Awaited<ReturnType<Info["process"]>>
export type StreamInput = Parameters<typeof streamText>[0]
// NOTE(review): placeholder type (literally named "TBD") with no visible
// consumer in this change — confirm something uses it or remove it.
export type TBD = {
// Minimal model identity pair (provider + model id).
model: {
modelID: string
providerID: string
}
}
export function create(input: {
assistantMessage: MessageV2.Assistant
sessionID: string
@ -38,13 +47,13 @@ export namespace SessionProcessor {
partFromToolCall(toolCallID: string) {
return toolcalls[toolCallID]
},
async process(fn: () => StreamTextResult<Record<string, AITool>, never>) {
async process(streamInput: StreamInput) {
log.info("process")
while (true) {
try {
let currentText: MessageV2.TextPart | undefined
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
const stream = fn()
const stream = streamText(streamInput)
for await (const value of stream.fullStream) {
input.abort.throwIfAborted()

View File

@ -11,7 +11,6 @@ import { Agent } from "../agent/agent"
import { Provider } from "../provider/provider"
import {
generateText,
streamText,
type ModelMessage,
type Tool as AITool,
tool,
@ -48,6 +47,7 @@ import { fn } from "@/util/fn"
import { SessionProcessor } from "./processor"
import { TaskTool } from "@/tool/task"
import { SessionStatus } from "./status"
import type { ModelsDev } from "@/provider/models"
// @ts-ignore
globalThis.AI_SDK_LOG_WARNINGS = false
@ -469,14 +469,15 @@ export namespace SessionPrompt {
})
const system = await resolveSystemPrompt({
providerID: model.providerID,
modelID: model.info.id,
model: model.info,
agent,
system: lastUser.system,
})
const tools = await resolveTools({
agent,
sessionID,
model: lastUser.model,
providerID: model.providerID,
model: model.info,
tools: lastUser.tools,
processor,
})
@ -492,14 +493,12 @@ export namespace SessionPrompt {
},
{
temperature: model.info.temperature
? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
? (agent.temperature ?? ProviderTransform.temperature(model.info))
: undefined,
topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
topP: agent.topP ?? ProviderTransform.topP(model.info),
options: pipe(
{},
mergeDeep(
ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID, provider?.options),
),
mergeDeep(ProviderTransform.options(model.providerID, model.info, model.npm, sessionID, provider?.options)),
mergeDeep(model.info.options),
mergeDeep(agent.options),
),
@ -513,113 +512,111 @@ export namespace SessionPrompt {
})
}
const result = await processor.process(() =>
streamText({
onError(error) {
log.error("stream error", {
error,
const result = await processor.process({
onError(error) {
log.error("stream error", {
error,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
},
async experimental_repairToolCall(input) {
const lower = input.toolCall.toolName.toLowerCase()
if (lower !== input.toolCall.toolName && tools[lower]) {
log.info("repairing tool call", {
tool: input.toolCall.toolName,
repaired: lower,
})
return {
...input.toolCall,
toolName: lower,
}
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
toolName: lower,
}
},
headers: {
...(model.providerID.startsWith("opencode")
? {
"x-opencode-project": Instance.project.id,
"x-opencode-session": sessionID,
"x-opencode-request": lastUser.id,
}
: undefined),
...model.info.headers,
},
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
}
return {
...input.toolCall,
input: JSON.stringify({
tool: input.toolCall.toolName,
error: input.error.message,
}),
toolName: "invalid",
}
},
headers: {
...(model.providerID.startsWith("opencode")
? {
"x-opencode-project": Instance.project.id,
"x-opencode-session": sessionID,
"x-opencode-request": lastUser.id,
}
: undefined),
...model.info.headers,
},
// set to 0, we handle loop
maxRetries: 0,
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
maxOutputTokens: ProviderTransform.maxOutputTokens(
model.providerID,
params.options,
model.info.limit.output,
OUTPUT_TOKEN_MAX,
),
abortSignal: abort,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
abortSignal: abort,
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
stopWhen: stepCountIs(1),
temperature: params.temperature,
topP: params.topP,
messages: [
...system.map(
(x): ModelMessage => ({
role: "system",
content: x,
}),
),
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
...MessageV2.toModelMessage(
msgs.filter((m) => {
if (m.info.role !== "assistant" || m.info.error === undefined) {
return true
}
if (
MessageV2.AbortedError.isInstance(m.info.error) &&
m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
) {
return true
}
return false
}),
),
],
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
}
// Transform tool schemas for provider compatibility
if (args.params.tools && Array.isArray(args.params.tools)) {
args.params.tools = args.params.tools.map((tool: any) => {
// Tools at middleware level have inputSchema, not parameters
if (tool.inputSchema && typeof tool.inputSchema === "object") {
// Transform the inputSchema for provider compatibility
return {
...tool,
inputSchema: ProviderTransform.schema(model.providerID, model.modelID, tool.inputSchema),
}
return false
}),
),
],
tools: model.info.tool_call === false ? undefined : tools,
model: wrapLanguageModel({
model: model.language,
middleware: [
{
async transformParams(args) {
if (args.type === "stream") {
// @ts-expect-error
args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.info)
}
// Transform tool schemas for provider compatibility
if (args.params.tools && Array.isArray(args.params.tools)) {
args.params.tools = args.params.tools.map((tool: any) => {
// Tools at middleware level have inputSchema, not parameters
if (tool.inputSchema && typeof tool.inputSchema === "object") {
// Transform the inputSchema for provider compatibility
return {
...tool,
inputSchema: ProviderTransform.schema(model.providerID, model.info, tool.inputSchema),
}
// If no inputSchema, return tool unchanged
return tool
})
}
return args.params
},
}
// If no inputSchema, return tool unchanged
return tool
})
}
return args.params
},
],
}),
},
],
}),
)
})
if (result === "stop") break
continue
}
@ -646,14 +643,14 @@ export namespace SessionPrompt {
system?: string
agent: Agent.Info
providerID: string
modelID: string
model: ModelsDev.Model
}) {
let system = SystemPrompt.header(input.providerID)
system.push(
...(() => {
if (input.system) return [input.system]
if (input.agent.prompt) return [input.agent.prompt]
return SystemPrompt.provider(input.modelID)
return SystemPrompt.provider(input.model)
})(),
)
system.push(...(await SystemPrompt.environment()))
@ -666,10 +663,8 @@ export namespace SessionPrompt {
async function resolveTools(input: {
agent: Agent.Info
model: {
providerID: string
modelID: string
}
providerID: string
model: ModelsDev.Model
sessionID: string
tools?: Record<string, boolean>
processor: SessionProcessor.Info
@ -677,16 +672,12 @@ export namespace SessionPrompt {
const tools: Record<string, AITool> = {}
const enabledTools = pipe(
input.agent.tools,
mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)),
mergeDeep(await ToolRegistry.enabled(input.agent)),
mergeDeep(input.tools ?? {}),
)
for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) {
for (const item of await ToolRegistry.tools(input.providerID)) {
if (Wildcard.all(item.id, enabledTools) === false) continue
const schema = ProviderTransform.schema(
input.model.providerID,
input.model.modelID,
z.toJSONSchema(item.parameters),
)
const schema = ProviderTransform.schema(input.providerID, input.model, z.toJSONSchema(item.parameters))
tools[item.id] = tool({
id: item.id as any,
description: item.description,
@ -1441,15 +1432,9 @@ export namespace SessionPrompt {
const options = pipe(
{},
mergeDeep(
ProviderTransform.options(
small.providerID,
small.modelID,
small.npm ?? "",
input.session.id,
provider?.options,
),
ProviderTransform.options(small.providerID, small.info, small.npm ?? "", input.session.id, provider?.options),
),
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, model: small.info })),
mergeDeep(small.info.options),
)
await generateText({

View File

@ -79,8 +79,8 @@ export namespace SessionSummary {
const options = pipe(
{},
mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", assistantMsg.sessionID)),
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
mergeDeep(ProviderTransform.options(small.providerID, small.info, small.npm ?? "", assistantMsg.sessionID)),
mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, model: small.info })),
mergeDeep(small.info.options),
)

View File

@ -17,6 +17,7 @@ import PROMPT_COMPACTION from "./prompt/compaction.txt"
import PROMPT_SUMMARIZE from "./prompt/summarize.txt"
import PROMPT_TITLE from "./prompt/title.txt"
import PROMPT_CODEX from "./prompt/codex.txt"
import type { ModelsDev } from "@/provider/models"
export namespace SystemPrompt {
export function header(providerID: string) {
@ -24,12 +25,13 @@ export namespace SystemPrompt {
return []
}
export function provider(modelID: string) {
if (modelID.includes("gpt-5")) return [PROMPT_CODEX]
if (modelID.includes("gpt-") || modelID.includes("o1") || modelID.includes("o3")) return [PROMPT_BEAST]
if (modelID.includes("gemini-")) return [PROMPT_GEMINI]
if (modelID.includes("claude")) return [PROMPT_ANTHROPIC]
if (modelID.includes("polaris-alpha")) return [PROMPT_POLARIS]
// Select the base system prompt by matching the provider-facing model id
// (model.target). Falls back to the Anthropic prompt without the todo
// section when nothing matches.
export function provider(model: ModelsDev.Model) {
const target = model.target
if (target.includes("gpt-5")) return [PROMPT_CODEX]
if (["gpt-", "o1", "o3"].some((needle) => target.includes(needle))) return [PROMPT_BEAST]
if (target.includes("gemini-")) return [PROMPT_GEMINI]
if (target.includes("claude")) return [PROMPT_ANTHROPIC]
if (target.includes("polaris-alpha")) return [PROMPT_POLARIS]
return [PROMPT_ANTHROPIC_WITHOUT_TODO]
}

View File

@ -37,7 +37,7 @@ export const BatchTool = Tool.define("batch", async () => {
const discardedCalls = params.tool_calls.slice(10)
const { ToolRegistry } = await import("./registry")
const availableTools = await ToolRegistry.tools("", "")
const availableTools = await ToolRegistry.tools("")
const toolMap = new Map(availableTools.map((t) => [t.id, t]))
const executeCall = async (call: (typeof toolCalls)[0]) => {

View File

@ -108,7 +108,7 @@ export namespace ToolRegistry {
return all().then((x) => x.map((t) => t.id))
}
export async function tools(providerID: string, _modelID: string) {
export async function tools(providerID: string) {
const tools = await all()
const result = await Promise.all(
tools
@ -124,11 +124,7 @@ export namespace ToolRegistry {
return result
}
export async function enabled(
_providerID: string,
_modelID: string,
agent: Agent.Info,
): Promise<Record<string, boolean>> {
export async function enabled(agent: Agent.Info): Promise<Record<string, boolean>> {
const result: Record<string, boolean> = {}
if (agent.permission.edit === "deny") {

View File

@ -1110,6 +1110,7 @@ export type Config = {
[key: string]: {
id?: string
name?: string
target?: string
release_date?: string
attachment?: boolean
reasoning?: boolean
@ -1355,6 +1356,7 @@ export type Command = {
export type Model = {
id: string
name: string
target: string
release_date: string
attachment: boolean
reasoning: boolean