fix(opencode): keep user message variants scoped to model (#21332)

pull/16657/merge
Dax 2026-04-07 10:12:53 -04:00 committed by GitHub
parent 01c5eb679c
commit 1f94c48bdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 42 additions and 36 deletions

View File

@ -146,7 +146,7 @@ beforeAll(async () => {
add: (value: { add: (value: {
directory?: string directory?: string
sessionID?: string sessionID?: string
message: { agent: string; model: { providerID: string; modelID: string }; variant?: string } message: { agent: string; model: { providerID: string; modelID: string; variant?: string } }
}) => { }) => {
optimistic.push(value) optimistic.push(value)
optimisticSeeded.push( optimisticSeeded.push(
@ -310,8 +310,7 @@ describe("prompt submit worktree selection", () => {
expect(optimistic[0]).toMatchObject({ expect(optimistic[0]).toMatchObject({
message: { message: {
agent: "agent", agent: "agent",
model: { providerID: "provider", modelID: "model" }, model: { providerID: "provider", modelID: "model", variant: "high" },
variant: "high",
}, },
}) })
}) })

View File

@ -121,8 +121,7 @@ export async function sendFollowupDraft(input: FollowupSendInput) {
role: "user", role: "user",
time: { created: Date.now() }, time: { created: Date.now() },
agent: input.draft.agent, agent: input.draft.agent,
model: input.draft.model, model: { ...input.draft.model, variant: input.draft.variant },
variant: input.draft.variant,
} }
const add = () => const add = () =>

View File

@ -11,7 +11,7 @@ import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } fro
import { useSDK } from "./sdk" import { useSDK } from "./sdk"
import { useSync } from "./sync" import { useSync } from "./sync"
export type ModelKey = { providerID: string; modelID: string } export type ModelKey = { providerID: string; modelID: string; variant?: string }
type State = { type State = {
agent?: string agent?: string
@ -373,7 +373,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
handoff.set(handoffKey(dir, session), next) handoff.set(handoffKey(dir, session), next)
setStore("draft", undefined) setStore("draft", undefined)
}, },
restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) { restore(msg: { sessionID: string; agent: string; model: ModelKey }) {
const session = id() const session = id()
if (!session) return if (!session) return
if (msg.sessionID !== session) return if (msg.sessionID !== session) return
@ -383,7 +383,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
setSaved("session", session, { setSaved("session", session, {
agent: msg.agent, agent: msg.agent,
model: msg.model, model: msg.model,
variant: msg.variant ?? null, variant: msg.model.variant ?? null,
}) })
}, },
}, },

View File

@ -416,8 +416,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
role: "user", role: "user",
time: { created: Date.now() }, time: { created: Date.now() },
agent: input.agent, agent: input.agent,
model: input.model, model: { ...input.model, variant: input.variant },
variant: input.variant,
} }
const [, setStore] = target() const [, setStore] = target()
setOptimistic(sdk.directory, input.sessionID, { message, parts: input.parts }) setOptimistic(sdk.directory, input.sessionID, { message, parts: input.parts })

View File

@ -2,7 +2,7 @@ import { describe, expect, test } from "bun:test"
import type { UserMessage } from "@opencode-ai/sdk/v2" import type { UserMessage } from "@opencode-ai/sdk/v2"
import { resetSessionModel, syncSessionModel } from "./session-model-helpers" import { resetSessionModel, syncSessionModel } from "./session-model-helpers"
const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant">>) => const message = (input?: { agent?: string; model?: UserMessage["model"] }) =>
({ ({
id: "msg", id: "msg",
sessionID: "session", sessionID: "session",
@ -10,7 +10,6 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
time: { created: 1 }, time: { created: 1 },
agent: input?.agent ?? "build", agent: input?.agent ?? "build",
model: input?.model ?? { providerID: "anthropic", modelID: "claude-sonnet-4" }, model: input?.model ?? { providerID: "anthropic", modelID: "claude-sonnet-4" },
variant: input?.variant,
}) as UserMessage }) as UserMessage
describe("syncSessionModel", () => { describe("syncSessionModel", () => {
@ -26,10 +25,12 @@ describe("syncSessionModel", () => {
reset() {}, reset() {},
}, },
}, },
message({ variant: "high" }), message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
) )
expect(calls).toEqual([message({ variant: "high" })]) expect(calls).toEqual([
message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
])
}) })
}) })

View File

@ -23,7 +23,7 @@ import { useRenderer, type JSX } from "@opentui/solid"
import { Editor } from "@tui/util/editor" import { Editor } from "@tui/util/editor"
import { useExit } from "../../context/exit" import { useExit } from "../../context/exit"
import { Clipboard } from "../../util/clipboard" import { Clipboard } from "../../util/clipboard"
import type { AssistantMessage, FilePart } from "@opencode-ai/sdk/v2" import type { AssistantMessage, FilePart, UserMessage } from "@opencode-ai/sdk/v2"
import { TuiEvent } from "../../event" import { TuiEvent } from "../../event"
import { iife } from "@/util/iife" import { iife } from "@/util/iife"
import { Locale } from "@/util/locale" import { Locale } from "@/util/locale"
@ -145,7 +145,7 @@ export function Prompt(props: PromptProps) {
if (!props.sessionID) return undefined if (!props.sessionID) return undefined
const messages = sync.data.message[props.sessionID] const messages = sync.data.message[props.sessionID]
if (!messages) return undefined if (!messages) return undefined
return messages.findLast((m) => m.role === "user") return messages.findLast((m): m is UserMessage => m.role === "user")
}) })
const usage = createMemo(() => { const usage = createMemo(() => {
@ -209,8 +209,10 @@ export function Prompt(props: PromptProps) {
const isPrimaryAgent = local.agent.list().some((x) => x.name === msg.agent) const isPrimaryAgent = local.agent.list().some((x) => x.name === msg.agent)
if (msg.agent && isPrimaryAgent) { if (msg.agent && isPrimaryAgent) {
local.agent.set(msg.agent) local.agent.set(msg.agent)
if (msg.model) local.model.set(msg.model) if (msg.model) {
if (msg.variant) local.model.variant.set(msg.variant) local.model.set(msg.model)
local.model.variant.set(msg.model.variant)
}
} }
} }
}) })

View File

@ -228,7 +228,7 @@ When constructing the summary, try to stick to this template:
sessionID: input.sessionID, sessionID: input.sessionID,
mode: "compaction", mode: "compaction",
agent: "compaction", agent: "compaction",
variant: userMessage.variant, variant: userMessage.model.variant,
summary: true, summary: true,
path: { path: {
cwd: ctx.directory, cwd: ctx.directory,
@ -295,7 +295,6 @@ When constructing the summary, try to stick to this template:
format: original.format, format: original.format,
tools: original.tools, tools: original.tools,
system: original.system, system: original.system,
variant: original.variant,
}) })
for (const part of replay.parts) { for (const part of replay.parts) {
if (part.type === "compaction") continue if (part.type === "compaction") continue

View File

@ -127,7 +127,9 @@ export namespace LLM {
} }
const variant = const variant =
!input.small && input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : {} !input.small && input.model.variants && input.user.model.variant
? input.model.variants[input.user.model.variant]
: {}
const base = input.small const base = input.small
? ProviderTransform.smallOptions(input.model) ? ProviderTransform.smallOptions(input.model)
: ProviderTransform.options({ : ProviderTransform.options({

View File

@ -371,10 +371,10 @@ export namespace MessageV2 {
model: z.object({ model: z.object({
providerID: ProviderID.zod, providerID: ProviderID.zod,
modelID: ModelID.zod, modelID: ModelID.zod,
variant: z.string().optional(),
}), }),
system: z.string().optional(), system: z.string().optional(),
tools: z.record(z.string(), z.boolean()).optional(), tools: z.record(z.string(), z.boolean()).optional(),
variant: z.string().optional(),
}).meta({ }).meta({
ref: "UserMessage", ref: "UserMessage",
}) })

View File

@ -569,7 +569,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
sessionID, sessionID,
mode: task.agent, mode: task.agent,
agent: task.agent, agent: task.agent,
variant: lastUser.variant, variant: lastUser.model.variant,
path: { cwd: ctx.directory, root: ctx.worktree }, path: { cwd: ctx.directory, root: ctx.worktree },
cost: 0, cost: 0,
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
@ -967,17 +967,20 @@ NOTE: At any point in time through this workflow you should feel free to ask the
: undefined : undefined
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined) const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
const info: MessageV2.Info = { const info: MessageV2.User = {
id: input.messageID ?? MessageID.ascending(), id: input.messageID ?? MessageID.ascending(),
role: "user", role: "user",
sessionID: input.sessionID, sessionID: input.sessionID,
time: { created: Date.now() }, time: { created: Date.now() },
tools: input.tools, tools: input.tools,
agent: ag.name, agent: ag.name,
model, model: {
providerID: model.providerID,
modelID: model.modelID,
variant,
},
system: input.system, system: input.system,
format: input.format, format: input.format,
variant,
} }
yield* Effect.addFinalizer(() => yield* Effect.addFinalizer(() =>
@ -1436,7 +1439,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
role: "assistant", role: "assistant",
mode: agent.name, mode: agent.name,
agent: agent.name, agent: agent.name,
variant: lastUser.variant, variant: lastUser.model.variant,
path: { cwd: ctx.directory, root: ctx.worktree }, path: { cwd: ctx.directory, root: ctx.worktree },
cost: 0, cost: 0,
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },

View File

@ -342,8 +342,7 @@ describe("session.llm.stream", () => {
role: "user", role: "user",
time: { created: Date.now() }, time: { created: Date.now() },
agent: agent.name, agent: agent.name,
model: { providerID: ProviderID.make(providerID), modelID: resolved.id }, model: { providerID: ProviderID.make(providerID), modelID: resolved.id, variant: "high" },
variant: "high",
} satisfies MessageV2.User } satisfies MessageV2.User
const stream = await LLM.stream({ const stream = await LLM.stream({
@ -716,8 +715,7 @@ describe("session.llm.stream", () => {
role: "user", role: "user",
time: { created: Date.now() }, time: { created: Date.now() },
agent: agent.name, agent: agent.name,
model: { providerID: ProviderID.make("openai"), modelID: resolved.id }, model: { providerID: ProviderID.make("openai"), modelID: resolved.id, variant: "high" },
variant: "high",
} satisfies MessageV2.User } satisfies MessageV2.User
const stream = await LLM.stream({ const stream = await LLM.stream({

View File

@ -410,7 +410,7 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello" }], parts: [{ type: "text", text: "hello" }],
}) })
if (other.info.role !== "user") throw new Error("expected user message") if (other.info.role !== "user") throw new Error("expected user message")
expect(other.info.variant).toBeUndefined() expect(other.info.model.variant).toBeUndefined()
const match = await SessionPrompt.prompt({ const match = await SessionPrompt.prompt({
sessionID: session.id, sessionID: session.id,
@ -419,8 +419,12 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello again" }], parts: [{ type: "text", text: "hello again" }],
}) })
if (match.info.role !== "user") throw new Error("expected user message") if (match.info.role !== "user") throw new Error("expected user message")
expect(match.info.model).toEqual({ providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-5.2") }) expect(match.info.model).toEqual({
expect(match.info.variant).toBe("xhigh") providerID: ProviderID.make("openai"),
modelID: ModelID.make("gpt-5.2"),
variant: "xhigh",
})
expect(match.info.model.variant).toBe("xhigh")
const override = await SessionPrompt.prompt({ const override = await SessionPrompt.prompt({
sessionID: session.id, sessionID: session.id,
@ -430,7 +434,7 @@ describe("session.prompt agent variant", () => {
parts: [{ type: "text", text: "hello third" }], parts: [{ type: "text", text: "hello third" }],
}) })
if (override.info.role !== "user") throw new Error("expected user message") if (override.info.role !== "user") throw new Error("expected user message")
expect(override.info.variant).toBe("high") expect(override.info.model.variant).toBe("high")
await Session.remove(session.id) await Session.remove(session.id)
}, },

View File

@ -548,12 +548,12 @@ export type UserMessage = {
model: { model: {
providerID: string providerID: string
modelID: string modelID: string
variant?: string
} }
system?: string system?: string
tools?: { tools?: {
[key: string]: boolean [key: string]: boolean
} }
variant?: string
} }
export type AssistantMessage = { export type AssistantMessage = {