fix(opencode): keep user message variants scoped to model (#21332)
parent
01c5eb679c
commit
1f94c48bdd
|
|
@ -146,7 +146,7 @@ beforeAll(async () => {
|
|||
add: (value: {
|
||||
directory?: string
|
||||
sessionID?: string
|
||||
message: { agent: string; model: { providerID: string; modelID: string }; variant?: string }
|
||||
message: { agent: string; model: { providerID: string; modelID: string; variant?: string } }
|
||||
}) => {
|
||||
optimistic.push(value)
|
||||
optimisticSeeded.push(
|
||||
|
|
@ -310,8 +310,7 @@ describe("prompt submit worktree selection", () => {
|
|||
expect(optimistic[0]).toMatchObject({
|
||||
message: {
|
||||
agent: "agent",
|
||||
model: { providerID: "provider", modelID: "model" },
|
||||
variant: "high",
|
||||
model: { providerID: "provider", modelID: "model", variant: "high" },
|
||||
},
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -121,8 +121,7 @@ export async function sendFollowupDraft(input: FollowupSendInput) {
|
|||
role: "user",
|
||||
time: { created: Date.now() },
|
||||
agent: input.draft.agent,
|
||||
model: input.draft.model,
|
||||
variant: input.draft.variant,
|
||||
model: { ...input.draft.model, variant: input.draft.variant },
|
||||
}
|
||||
|
||||
const add = () =>
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import { cycleModelVariant, getConfiguredAgentVariant, resolveModelVariant } fro
|
|||
import { useSDK } from "./sdk"
|
||||
import { useSync } from "./sync"
|
||||
|
||||
export type ModelKey = { providerID: string; modelID: string }
|
||||
export type ModelKey = { providerID: string; modelID: string; variant?: string }
|
||||
|
||||
type State = {
|
||||
agent?: string
|
||||
|
|
@ -373,7 +373,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
|
|||
handoff.set(handoffKey(dir, session), next)
|
||||
setStore("draft", undefined)
|
||||
},
|
||||
restore(msg: { sessionID: string; agent: string; model: ModelKey; variant?: string }) {
|
||||
restore(msg: { sessionID: string; agent: string; model: ModelKey }) {
|
||||
const session = id()
|
||||
if (!session) return
|
||||
if (msg.sessionID !== session) return
|
||||
|
|
@ -383,7 +383,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
|
|||
setSaved("session", session, {
|
||||
agent: msg.agent,
|
||||
model: msg.model,
|
||||
variant: msg.variant ?? null,
|
||||
variant: msg.model.variant ?? null,
|
||||
})
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -416,8 +416,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({
|
|||
role: "user",
|
||||
time: { created: Date.now() },
|
||||
agent: input.agent,
|
||||
model: input.model,
|
||||
variant: input.variant,
|
||||
model: { ...input.model, variant: input.variant },
|
||||
}
|
||||
const [, setStore] = target()
|
||||
setOptimistic(sdk.directory, input.sessionID, { message, parts: input.parts })
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { describe, expect, test } from "bun:test"
|
|||
import type { UserMessage } from "@opencode-ai/sdk/v2"
|
||||
import { resetSessionModel, syncSessionModel } from "./session-model-helpers"
|
||||
|
||||
const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant">>) =>
|
||||
const message = (input?: { agent?: string; model?: UserMessage["model"] }) =>
|
||||
({
|
||||
id: "msg",
|
||||
sessionID: "session",
|
||||
|
|
@ -10,7 +10,6 @@ const message = (input?: Partial<Pick<UserMessage, "agent" | "model" | "variant"
|
|||
time: { created: 1 },
|
||||
agent: input?.agent ?? "build",
|
||||
model: input?.model ?? { providerID: "anthropic", modelID: "claude-sonnet-4" },
|
||||
variant: input?.variant,
|
||||
}) as UserMessage
|
||||
|
||||
describe("syncSessionModel", () => {
|
||||
|
|
@ -26,10 +25,12 @@ describe("syncSessionModel", () => {
|
|||
reset() {},
|
||||
},
|
||||
},
|
||||
message({ variant: "high" }),
|
||||
message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
|
||||
)
|
||||
|
||||
expect(calls).toEqual([message({ variant: "high" })])
|
||||
expect(calls).toEqual([
|
||||
message({ model: { providerID: "anthropic", modelID: "claude-sonnet-4", variant: "high" } }),
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import { useRenderer, type JSX } from "@opentui/solid"
|
|||
import { Editor } from "@tui/util/editor"
|
||||
import { useExit } from "../../context/exit"
|
||||
import { Clipboard } from "../../util/clipboard"
|
||||
import type { AssistantMessage, FilePart } from "@opencode-ai/sdk/v2"
|
||||
import type { AssistantMessage, FilePart, UserMessage } from "@opencode-ai/sdk/v2"
|
||||
import { TuiEvent } from "../../event"
|
||||
import { iife } from "@/util/iife"
|
||||
import { Locale } from "@/util/locale"
|
||||
|
|
@ -145,7 +145,7 @@ export function Prompt(props: PromptProps) {
|
|||
if (!props.sessionID) return undefined
|
||||
const messages = sync.data.message[props.sessionID]
|
||||
if (!messages) return undefined
|
||||
return messages.findLast((m) => m.role === "user")
|
||||
return messages.findLast((m): m is UserMessage => m.role === "user")
|
||||
})
|
||||
|
||||
const usage = createMemo(() => {
|
||||
|
|
@ -209,8 +209,10 @@ export function Prompt(props: PromptProps) {
|
|||
const isPrimaryAgent = local.agent.list().some((x) => x.name === msg.agent)
|
||||
if (msg.agent && isPrimaryAgent) {
|
||||
local.agent.set(msg.agent)
|
||||
if (msg.model) local.model.set(msg.model)
|
||||
if (msg.variant) local.model.variant.set(msg.variant)
|
||||
if (msg.model) {
|
||||
local.model.set(msg.model)
|
||||
local.model.variant.set(msg.model.variant)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ When constructing the summary, try to stick to this template:
|
|||
sessionID: input.sessionID,
|
||||
mode: "compaction",
|
||||
agent: "compaction",
|
||||
variant: userMessage.variant,
|
||||
variant: userMessage.model.variant,
|
||||
summary: true,
|
||||
path: {
|
||||
cwd: ctx.directory,
|
||||
|
|
@ -295,7 +295,6 @@ When constructing the summary, try to stick to this template:
|
|||
format: original.format,
|
||||
tools: original.tools,
|
||||
system: original.system,
|
||||
variant: original.variant,
|
||||
})
|
||||
for (const part of replay.parts) {
|
||||
if (part.type === "compaction") continue
|
||||
|
|
|
|||
|
|
@ -127,7 +127,9 @@ export namespace LLM {
|
|||
}
|
||||
|
||||
const variant =
|
||||
!input.small && input.model.variants && input.user.variant ? input.model.variants[input.user.variant] : {}
|
||||
!input.small && input.model.variants && input.user.model.variant
|
||||
? input.model.variants[input.user.model.variant]
|
||||
: {}
|
||||
const base = input.small
|
||||
? ProviderTransform.smallOptions(input.model)
|
||||
: ProviderTransform.options({
|
||||
|
|
|
|||
|
|
@ -371,10 +371,10 @@ export namespace MessageV2 {
|
|||
model: z.object({
|
||||
providerID: ProviderID.zod,
|
||||
modelID: ModelID.zod,
|
||||
variant: z.string().optional(),
|
||||
}),
|
||||
system: z.string().optional(),
|
||||
tools: z.record(z.string(), z.boolean()).optional(),
|
||||
variant: z.string().optional(),
|
||||
}).meta({
|
||||
ref: "UserMessage",
|
||||
})
|
||||
|
|
|
|||
|
|
@ -569,7 +569,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
sessionID,
|
||||
mode: task.agent,
|
||||
agent: task.agent,
|
||||
variant: lastUser.variant,
|
||||
variant: lastUser.model.variant,
|
||||
path: { cwd: ctx.directory, root: ctx.worktree },
|
||||
cost: 0,
|
||||
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
||||
|
|
@ -967,17 +967,20 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
: undefined
|
||||
const variant = input.variant ?? (ag.variant && full?.variants?.[ag.variant] ? ag.variant : undefined)
|
||||
|
||||
const info: MessageV2.Info = {
|
||||
const info: MessageV2.User = {
|
||||
id: input.messageID ?? MessageID.ascending(),
|
||||
role: "user",
|
||||
sessionID: input.sessionID,
|
||||
time: { created: Date.now() },
|
||||
tools: input.tools,
|
||||
agent: ag.name,
|
||||
model,
|
||||
model: {
|
||||
providerID: model.providerID,
|
||||
modelID: model.modelID,
|
||||
variant,
|
||||
},
|
||||
system: input.system,
|
||||
format: input.format,
|
||||
variant,
|
||||
}
|
||||
|
||||
yield* Effect.addFinalizer(() =>
|
||||
|
|
@ -1436,7 +1439,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
|
|||
role: "assistant",
|
||||
mode: agent.name,
|
||||
agent: agent.name,
|
||||
variant: lastUser.variant,
|
||||
variant: lastUser.model.variant,
|
||||
path: { cwd: ctx.directory, root: ctx.worktree },
|
||||
cost: 0,
|
||||
tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } },
|
||||
|
|
|
|||
|
|
@ -342,8 +342,7 @@ describe("session.llm.stream", () => {
|
|||
role: "user",
|
||||
time: { created: Date.now() },
|
||||
agent: agent.name,
|
||||
model: { providerID: ProviderID.make(providerID), modelID: resolved.id },
|
||||
variant: "high",
|
||||
model: { providerID: ProviderID.make(providerID), modelID: resolved.id, variant: "high" },
|
||||
} satisfies MessageV2.User
|
||||
|
||||
const stream = await LLM.stream({
|
||||
|
|
@ -716,8 +715,7 @@ describe("session.llm.stream", () => {
|
|||
role: "user",
|
||||
time: { created: Date.now() },
|
||||
agent: agent.name,
|
||||
model: { providerID: ProviderID.make("openai"), modelID: resolved.id },
|
||||
variant: "high",
|
||||
model: { providerID: ProviderID.make("openai"), modelID: resolved.id, variant: "high" },
|
||||
} satisfies MessageV2.User
|
||||
|
||||
const stream = await LLM.stream({
|
||||
|
|
|
|||
|
|
@ -410,7 +410,7 @@ describe("session.prompt agent variant", () => {
|
|||
parts: [{ type: "text", text: "hello" }],
|
||||
})
|
||||
if (other.info.role !== "user") throw new Error("expected user message")
|
||||
expect(other.info.variant).toBeUndefined()
|
||||
expect(other.info.model.variant).toBeUndefined()
|
||||
|
||||
const match = await SessionPrompt.prompt({
|
||||
sessionID: session.id,
|
||||
|
|
@ -419,8 +419,12 @@ describe("session.prompt agent variant", () => {
|
|||
parts: [{ type: "text", text: "hello again" }],
|
||||
})
|
||||
if (match.info.role !== "user") throw new Error("expected user message")
|
||||
expect(match.info.model).toEqual({ providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-5.2") })
|
||||
expect(match.info.variant).toBe("xhigh")
|
||||
expect(match.info.model).toEqual({
|
||||
providerID: ProviderID.make("openai"),
|
||||
modelID: ModelID.make("gpt-5.2"),
|
||||
variant: "xhigh",
|
||||
})
|
||||
expect(match.info.model.variant).toBe("xhigh")
|
||||
|
||||
const override = await SessionPrompt.prompt({
|
||||
sessionID: session.id,
|
||||
|
|
@ -430,7 +434,7 @@ describe("session.prompt agent variant", () => {
|
|||
parts: [{ type: "text", text: "hello third" }],
|
||||
})
|
||||
if (override.info.role !== "user") throw new Error("expected user message")
|
||||
expect(override.info.variant).toBe("high")
|
||||
expect(override.info.model.variant).toBe("high")
|
||||
|
||||
await Session.remove(session.id)
|
||||
},
|
||||
|
|
|
|||
|
|
@ -548,12 +548,12 @@ export type UserMessage = {
|
|||
model: {
|
||||
providerID: string
|
||||
modelID: string
|
||||
variant?: string
|
||||
}
|
||||
system?: string
|
||||
tools?: {
|
||||
[key: string]: boolean
|
||||
}
|
||||
variant?: string
|
||||
}
|
||||
|
||||
export type AssistantMessage = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue