fix(opencode): keep user message variants scoped to model (#21332)

pull/16657/merge
Dax 2026-04-07 10:12:53 -04:00 committed by GitHub
parent 01c5eb679c
commit 1f94c48bdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 42 additions and 36 deletions

View File

@ -146,7 +146,7 @@ beforeAll(async () => {
add: (value: {
directory?: string
sessionID?: string
message: { agent: string; model: { providerID: string; modelID: string }; variant?: string }
message: { agent: string; model: { providerID: string; modelID: string; variant?: string } }
}) => {
optimistic.push(value)
optimisticSeeded.push(
@ -310,8 +310,7 @@ describe("prompt submit worktree selection", () => {
expect(optimistic[0]).toMatchObject({
message: {
agent: "agent",
model: { providerID: "provider", modelID: "model" },
variant: "high",
model: { providerID: "provider", modelID: "model", variant: "high" },
},
})
})

View File

@ -121,8 +121,7 @@ export async function sendFollowupDraft(input: FollowupSendInput) {
role: "user",
time: { created: Date.now() },
agent: input.draft.agent,
model: input.draft.model,
variant: input.draft.variant,
model: { ...input.draft.model, variant: input.draft.variant },
}
const add = () =>

View File

@ -11,7 +11,7 @@ import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } fro
import { useSDK } from "./sdk"
import { useSync } from "./sync"
export type ModelKey = { providerID: string; modelID: string }
export type ModelKey = { providerID: string; modelID: string; variant?: string }
type State = {
agent?: string
@ -373,7 +373,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
handoff.set(handoffKey(dir, session), next)
setStore("draft", undefined)
},
restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) {
restore(msg: { sessionID: string; agent: string; model: ModelKey }) {
const session = id()
if (!session) return
if (msg.sessionID !== session) return
@ -383,7 +383,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
setSaved("session", session, {
agent: msg.agent,
model: msg.model,
variant: msg.variant ?? null,
variant: msg.model.variant ?? null,
})
},
},

View File

@ -416,8 +416,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
role: "user",
time: { created: Date.now() },
agent: input.agent,
model: input.model,
variant: input.variant,
model: { ...input.model, variant: input.variant },
}
const [, setStore] = target()
setOptimistic(sdk.directory, input.sessionID, { message, parts: input.parts })

View File

@ -2,7 +2,7 @@ import { describe, expect, test } from "bun:test"
import type { UserMessage } from "@opencode-ai/sdk/v2"
import { resetSessionModel, syncSessionModel } from "./session-model-helpers"
const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant">>) =>
const message = (input?: { agent?: string; model?: UserMessage["model"] }) =>
({
id: "msg",
sessionID: "session",
@ -10,7 +10,6 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
time: { created: 1 },
agent: input?.agent ?? "build",
model: input?.model ?? { providerID: "anthropic", modelID: "claude-sonnet-4" },
variant: input?.variant,
}) as UserMessage
describe("syncSessionModel", () => {
@ -26,10 +25,12 @@ describe("syncSessionModel", () => {
reset() {},
},
},
message({ variant: "high" }),
message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
)
expect(calls).toEqual([message({ variant: "high" })])
expect(calls).toEqual([
message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
])
})
})

View File

@ -23,7 +23,7 @@ import { useRenderer, type JSX } from "@opentui/solid"
import { Editor } from "@tui/util/editor"
import { useExit } from "../../context/exit"
import { Clipboard } from "../../util/clipboard"
import type { AssistantMessage, FilePart } from "@opencode-ai/sdk/v2"
import type { AssistantMessage, FilePart, UserMessage } from "@opencode-ai/sdk/v2"
import { TuiEvent } from "../../event"
import { iife } from "@/util/iife"
import { Locale } from "@/util/locale"
@ -145,7 +145,7 @@ export function Prompt(props: PromptProps) {
if (!props.sessionID) return undefined
const messages = sync.data.message[props.sessionID]
if (!messages) return undefined
return messages.findLast((m) => m.role === "user")
return messages.findLast((m): m is UserMessage => m.role === "user")
})
const usage = createMemo(() => {
@ -209,8 +209,10 @@ export function Prompt(props: PromptProps) {
const isPrimaryAgent = local.agent.list().some((x) => x.name === msg.agent)
if (msg.agent && isPrimaryAgent) {
local.agent.set(msg.agent)
if (msg.model) local.model.set(msg.model)
if (msg.variant) local.model.variant.set(msg.variant)
if (msg.model) {
local.model.set(msg.model)
local.model.variant.set(msg.model.variant)
}
}
}
})

View File

@ -228,7 +228,7 @@ When constructing the summary, try to stick to this template:
sessionID: input.sessionID,
mode: "compaction",
agent: "compaction",
variant: userMessage.variant,
variant: userMessage.model.variant,
summary: true,
path: {
cwd: ctx.directory,
@ -295,7 +295,6 @@ When constructing the summary, try to stick to this template:
format: original.format,
tools: original.tools,
system: original.system,
variant: original.variant,
})
for (const part of replay.parts) {
if (part.type === "compaction") continue

View File

@ -127,7 +127,9 @@ export namespace LLM {
}
const variant =
!input.small && input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : {}
!input.small && input.model.variants && input.user.model.variant
? input.model.variants[input.user.model.variant]
: {}
const base = input.small
? ProviderTransform.smallOptions(input.model)
: ProviderTransform.options({

View File

@ -371,10 +371,10 @@ export namespace MessageV2 {
model: z.object({
providerID: ProviderID.zod,
modelID: ModelID.zod,
variant: z.string().optional(),
}),
system: z.string().optional(),
tools: z.record(z.string(), z.boolean()).optional(),
variant: z.string().optional(),
}).meta({
ref: "UserMessage",
})

View File

@ -569,7 +569,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
sessionID,
mode: task.agent,
agent: task.agent,
variant: lastUser.variant,
variant: lastUser.model.variant,
path: { cwd: ctx.directory, root: ctx.worktree },
cost: 0,
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
@ -967,17 +967,20 @@ NOTE: At any point in time through this workflow you should feel free to ask the
: undefined
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
const info: MessageV2.Info = {
const info: MessageV2.User = {
id: input.messageID ?? MessageID.ascending(),
role: "user",
sessionID: input.sessionID,
time: { created: Date.now() },
tools: input.tools,
agent: ag.name,
model,
model: {
providerID: model.providerID,
modelID: model.modelID,
variant,
},
system: input.system,
format: input.format,
variant,
}
yield* Effect.addFinalizer(() =>
@ -1436,7 +1439,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
role: "assistant",
mode: agent.name,
agent: agent.name,
variant: lastUser.variant,
variant: lastUser.model.variant,
path: { cwd: ctx.directory, root: ctx.worktree },
cost: 0,
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },

View File

@ -342,8 +342,7 @@ describe("session.llm.stream", () => {
role: "user",
time: { created: Date.now() },
agent: agent.name,
model: { providerID: ProviderID.make(providerID), modelID: resolved.id },
variant: "high",
model: { providerID: ProviderID.make(providerID), modelID: resolved.id, variant: "high" },
} satisfies MessageV2.User
const stream = await LLM.stream({
@ -716,8 +715,7 @@ describe("session.llm.stream", () => {
role: "user",
time: { created: Date.now() },
agent: agent.name,
model: { providerID: ProviderID.make("openai"), modelID: resolved.id },
variant: "high",
model: { providerID: ProviderID.make("openai"), modelID: resolved.id, variant: "high" },
} satisfies MessageV2.User
const stream = await LLM.stream({

View File

@ -410,7 +410,7 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello" }],
})
if (other.info.role !== "user") throw new Error("expected user message")
expect(other.info.variant).toBeUndefined()
expect(other.info.model.variant).toBeUndefined()
const match = await SessionPrompt.prompt({
sessionID: session.id,
@ -419,8 +419,12 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello again" }],
})
if (match.info.role !== "user") throw new Error("expected user message")
expect(match.info.model).toEqual({ providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-5.2") })
expect(match.info.variant).toBe("xhigh")
expect(match.info.model).toEqual({
providerID: ProviderID.make("openai"),
modelID: ModelID.make("gpt-5.2"),
variant: "xhigh",
})
expect(match.info.model.variant).toBe("xhigh")
const override = await SessionPrompt.prompt({
sessionID: session.id,
@ -430,7 +434,7 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello third" }],
})
if (override.info.role !== "user") throw new Error("expected user message")
expect(override.info.variant).toBe("high")
expect(override.info.model.variant).toBe("high")
await Session.remove(session.id)
},

View File

@ -548,12 +548,12 @@ export type UserMessage = {
model: {
providerID: string
modelID: string
variant?: string
}
system?: string
tools?: {
[key: string]: boolean
}
variant?: string
}
export type AssistantMessage = {